1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 using namespace test;
22 
23 // Native function for testing NativeCodeCall
24 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
25   return choice.getValue() ? input1 : input2;
26 }
27 
28 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
29   rewriter.create<OpI>(loc, input);
30 }
31 
32 static void handleNoResultOp(PatternRewriter &rewriter,
33                              OpSymbolBindingNoResult op) {
34   // Turn the no result op to a one-result op.
35   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
36                                     op.getOperand());
37 }
38 
39 static bool getFirstI32Result(Operation *op, Value &value) {
40   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
41     return false;
42   value = op->getResult(0);
43   return true;
44 }
45 
46 static Value bindNativeCodeCallResult(Value value) { return value; }
47 
48 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
49                                                               Value input2) {
50   return SmallVector<Value, 2>({input2, input1});
51 }
52 
53 // Test that natives calls are only called once during rewrites.
54 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
55 // This let us check the number of times OpM_Test was called by inspecting
56 // the returned value in the MLIR output.
57 static int64_t opMIncreasingValue = 314159265;
58 static Attribute OpMTest(PatternRewriter &rewriter, Value val) {
59   int64_t i = opMIncreasingValue++;
60   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
61 }
62 
63 namespace {
64 #include "TestPatterns.inc"
65 } // end anonymous namespace
66 
67 //===----------------------------------------------------------------------===//
68 // Test Reduce Pattern Interface
69 //===----------------------------------------------------------------------===//
70 
71 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
72   populateWithGenerated(patterns);
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // Canonicalizer Driver.
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
80 struct FoldingPattern : public RewritePattern {
81 public:
82   FoldingPattern(MLIRContext *context)
83       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
84                        /*benefit=*/1, context) {}
85 
86   LogicalResult matchAndRewrite(Operation *op,
87                                 PatternRewriter &rewriter) const override {
88     // Exercise OperationFolder API for a single-result operation that is folded
89     // upon construction. The operation being created through the folder has an
90     // in-place folder, and it should be still present in the output.
91     // Furthermore, the folder should not crash when attempting to recover the
92     // (unchanged) operation result.
93     OperationFolder folder(op->getContext());
94     Value result = folder.create<TestOpInPlaceFold>(
95         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
96         rewriter.getI32IntegerAttr(0));
97     assert(result);
98     rewriter.replaceOp(op, result);
99     return success();
100   }
101 };
102 
103 /// This pattern creates a foldable operation at the entry point of the block.
104 /// This tests the situation where the operation folder will need to replace an
105 /// operation with a previously created constant that does not initially
106 /// dominate the operation to replace.
107 struct FolderInsertBeforePreviouslyFoldedConstantPattern
108     : public OpRewritePattern<TestCastOp> {
109 public:
110   using OpRewritePattern<TestCastOp>::OpRewritePattern;
111 
112   LogicalResult matchAndRewrite(TestCastOp op,
113                                 PatternRewriter &rewriter) const override {
114     if (!op->hasAttr("test_fold_before_previously_folded_op"))
115       return failure();
116     rewriter.setInsertionPointToStart(op->getBlock());
117 
118     auto constOp = rewriter.create<arith::ConstantOp>(
119         op.getLoc(), rewriter.getBoolAttr(true));
120     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
121                                             Value(constOp));
122     return success();
123   }
124 };
125 
126 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
127   StringRef getArgument() const final { return "test-patterns"; }
128   StringRef getDescription() const final { return "Run test dialect patterns"; }
129   void runOnFunction() override {
130     mlir::RewritePatternSet patterns(&getContext());
131     populateWithGenerated(patterns);
132 
133     // Verify named pattern is generated with expected name.
134     patterns.add<FoldingPattern, TestNamedPatternRule,
135                  FolderInsertBeforePreviouslyFoldedConstantPattern>(
136         &getContext());
137 
138     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139   }
140 };
141 } // end anonymous namespace
142 
143 //===----------------------------------------------------------------------===//
144 // ReturnType Driver.
145 //===----------------------------------------------------------------------===//
146 
147 namespace {
148 // Generate ops for each instance where the type can be successfully inferred.
149 template <typename OpTy>
150 static void invokeCreateWithInferredReturnType(Operation *op) {
151   auto *context = op->getContext();
152   auto fop = op->getParentOfType<FuncOp>();
153   auto location = UnknownLoc::get(context);
154   OpBuilder b(op);
155   b.setInsertionPointAfter(op);
156 
157   // Use permutations of 2 args as operands.
158   assert(fop.getNumArguments() >= 2);
159   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
160     for (int j = 0; j < e; ++j) {
161       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
162       SmallVector<Type, 2> inferredReturnTypes;
163       if (succeeded(OpTy::inferReturnTypes(
164               context, llvm::None, values, op->getAttrDictionary(),
165               op->getRegions(), inferredReturnTypes))) {
166         OperationState state(location, OpTy::getOperationName());
167         // TODO: Expand to regions.
168         OpTy::build(b, state, values, op->getAttrs());
169         (void)b.createOperation(state);
170       }
171     }
172   }
173 }
174 
175 static void reifyReturnShape(Operation *op) {
176   OpBuilder b(op);
177 
178   // Use permutations of 2 args as operands.
179   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
180   SmallVector<Value, 2> shapes;
181   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
182       !llvm::hasSingleElement(shapes))
183     return;
184   for (auto it : llvm::enumerate(shapes)) {
185     op->emitRemark() << "value " << it.index() << ": "
186                      << it.value().getDefiningOp();
187   }
188 }
189 
190 struct TestReturnTypeDriver
191     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
192   void getDependentDialects(DialectRegistry &registry) const override {
193     registry.insert<tensor::TensorDialect>();
194   }
195   StringRef getArgument() const final { return "test-return-type"; }
196   StringRef getDescription() const final { return "Run return type functions"; }
197 
198   void runOnFunction() override {
199     if (getFunction().getName() == "testCreateFunctions") {
200       std::vector<Operation *> ops;
201       // Collect ops to avoid triggering on inserted ops.
202       for (auto &op : getFunction().getBody().front())
203         ops.push_back(&op);
204       // Generate test patterns for each, but skip terminator.
205       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
206         // Test create method of each of the Op classes below. The resultant
207         // output would be in reverse order underneath `op` from which
208         // the attributes and regions are used.
209         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
210         invokeCreateWithInferredReturnType<
211             OpWithShapedTypeInferTypeInterfaceOp>(op);
212       };
213       return;
214     }
215     if (getFunction().getName() == "testReifyFunctions") {
216       std::vector<Operation *> ops;
217       // Collect ops to avoid triggering on inserted ops.
218       for (auto &op : getFunction().getBody().front())
219         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
220           ops.push_back(&op);
221       // Generate test patterns for each, but skip terminator.
222       for (auto *op : ops)
223         reifyReturnShape(op);
224     }
225   }
226 };
227 } // end anonymous namespace
228 
229 namespace {
230 struct TestDerivedAttributeDriver
231     : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
232   StringRef getArgument() const final { return "test-derived-attr"; }
233   StringRef getDescription() const final {
234     return "Run test derived attributes";
235   }
236   void runOnFunction() override;
237 };
238 } // end anonymous namespace
239 
240 void TestDerivedAttributeDriver::runOnFunction() {
241   getFunction().walk([](DerivedAttributeOpInterface dOp) {
242     auto dAttr = dOp.materializeDerivedAttributes();
243     if (!dAttr)
244       return;
245     for (auto d : dAttr)
246       dOp.emitRemark() << d.first.getValue() << " = " << d.second;
247   });
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // Legalization Driver.
252 //===----------------------------------------------------------------------===//
253 
254 namespace {
255 //===----------------------------------------------------------------------===//
256 // Region-Block Rewrite Testing
257 
258 /// This pattern is a simple pattern that inlines the first region of a given
259 /// operation into the parent region.
260 struct TestRegionRewriteBlockMovement : public ConversionPattern {
261   TestRegionRewriteBlockMovement(MLIRContext *ctx)
262       : ConversionPattern("test.region", 1, ctx) {}
263 
264   LogicalResult
265   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
266                   ConversionPatternRewriter &rewriter) const final {
267     // Inline this region into the parent region.
268     auto &parentRegion = *op->getParentRegion();
269     auto &opRegion = op->getRegion(0);
270     if (op->getAttr("legalizer.should_clone"))
271       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
272     else
273       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
274 
275     if (op->getAttr("legalizer.erase_old_blocks")) {
276       while (!opRegion.empty())
277         rewriter.eraseBlock(&opRegion.front());
278     }
279 
280     // Drop this operation.
281     rewriter.eraseOp(op);
282     return success();
283   }
284 };
285 /// This pattern is a simple pattern that generates a region containing an
286 /// illegal operation.
287 struct TestRegionRewriteUndo : public RewritePattern {
288   TestRegionRewriteUndo(MLIRContext *ctx)
289       : RewritePattern("test.region_builder", 1, ctx) {}
290 
291   LogicalResult matchAndRewrite(Operation *op,
292                                 PatternRewriter &rewriter) const final {
293     // Create the region operation with an entry block containing arguments.
294     OperationState newRegion(op->getLoc(), "test.region");
295     newRegion.addRegion();
296     auto *regionOp = rewriter.createOperation(newRegion);
297     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
298     entryBlock->addArgument(rewriter.getIntegerType(64));
299 
300     // Add an explicitly illegal operation to ensure the conversion fails.
301     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
302     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
303 
304     // Drop this operation.
305     rewriter.eraseOp(op);
306     return success();
307   }
308 };
309 /// A simple pattern that creates a block at the end of the parent region of the
310 /// matched operation.
311 struct TestCreateBlock : public RewritePattern {
312   TestCreateBlock(MLIRContext *ctx)
313       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
314 
315   LogicalResult matchAndRewrite(Operation *op,
316                                 PatternRewriter &rewriter) const final {
317     Region &region = *op->getParentRegion();
318     Type i32Type = rewriter.getIntegerType(32);
319     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
320     rewriter.create<TerminatorOp>(op->getLoc());
321     rewriter.replaceOp(op, {});
322     return success();
323   }
324 };
325 
326 /// A simple pattern that creates a block containing an invalid operation in
327 /// order to trigger the block creation undo mechanism.
328 struct TestCreateIllegalBlock : public RewritePattern {
329   TestCreateIllegalBlock(MLIRContext *ctx)
330       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
331 
332   LogicalResult matchAndRewrite(Operation *op,
333                                 PatternRewriter &rewriter) const final {
334     Region &region = *op->getParentRegion();
335     Type i32Type = rewriter.getIntegerType(32);
336     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
337     // Create an illegal op to ensure the conversion fails.
338     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
339     rewriter.create<TerminatorOp>(op->getLoc());
340     rewriter.replaceOp(op, {});
341     return success();
342   }
343 };
344 
345 /// A simple pattern that tests the undo mechanism when replacing the uses of a
346 /// block argument.
347 struct TestUndoBlockArgReplace : public ConversionPattern {
348   TestUndoBlockArgReplace(MLIRContext *ctx)
349       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
350 
351   LogicalResult
352   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
353                   ConversionPatternRewriter &rewriter) const final {
354     auto illegalOp =
355         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
356     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
357                                         illegalOp);
358     rewriter.updateRootInPlace(op, [] {});
359     return success();
360   }
361 };
362 
363 /// A rewrite pattern that tests the undo mechanism when erasing a block.
364 struct TestUndoBlockErase : public ConversionPattern {
365   TestUndoBlockErase(MLIRContext *ctx)
366       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
367 
368   LogicalResult
369   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
370                   ConversionPatternRewriter &rewriter) const final {
371     Block *secondBlock = &*std::next(op->getRegion(0).begin());
372     rewriter.setInsertionPointToStart(secondBlock);
373     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
374     rewriter.eraseBlock(secondBlock);
375     rewriter.updateRootInPlace(op, [] {});
376     return success();
377   }
378 };
379 
380 //===----------------------------------------------------------------------===//
381 // Type-Conversion Rewrite Testing
382 
383 /// This patterns erases a region operation that has had a type conversion.
384 struct TestDropOpSignatureConversion : public ConversionPattern {
385   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
386       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
387   LogicalResult
388   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
389                   ConversionPatternRewriter &rewriter) const override {
390     Region &region = op->getRegion(0);
391     Block *entry = &region.front();
392 
393     // Convert the original entry arguments.
394     TypeConverter &converter = *getTypeConverter();
395     TypeConverter::SignatureConversion result(entry->getNumArguments());
396     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
397                                               result)) ||
398         failed(rewriter.convertRegionTypes(&region, converter, &result)))
399       return failure();
400 
401     // Convert the region signature and just drop the operation.
402     rewriter.eraseOp(op);
403     return success();
404   }
405 };
406 /// This pattern simply updates the operands of the given operation.
407 struct TestPassthroughInvalidOp : public ConversionPattern {
408   TestPassthroughInvalidOp(MLIRContext *ctx)
409       : ConversionPattern("test.invalid", 1, ctx) {}
410   LogicalResult
411   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
412                   ConversionPatternRewriter &rewriter) const final {
413     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
414                                              llvm::None);
415     return success();
416   }
417 };
418 /// This pattern handles the case of a split return value.
419 struct TestSplitReturnType : public ConversionPattern {
420   TestSplitReturnType(MLIRContext *ctx)
421       : ConversionPattern("test.return", 1, ctx) {}
422   LogicalResult
423   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
424                   ConversionPatternRewriter &rewriter) const final {
425     // Check for a return of F32.
426     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
427       return failure();
428 
429     // Check if the first operation is a cast operation, if it is we use the
430     // results directly.
431     auto *defOp = operands[0].getDefiningOp();
432     if (auto packerOp =
433             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
434       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
435       return success();
436     }
437 
438     // Otherwise, fail to match.
439     return failure();
440   }
441 };
442 
443 //===----------------------------------------------------------------------===//
444 // Multi-Level Type-Conversion Rewrite Testing
445 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
446   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
447       : ConversionPattern("test.type_producer", 1, ctx) {}
448   LogicalResult
449   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
450                   ConversionPatternRewriter &rewriter) const final {
451     // If the type is I32, change the type to F32.
452     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
453       return failure();
454     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
455     return success();
456   }
457 };
458 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
459   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
460       : ConversionPattern("test.type_producer", 1, ctx) {}
461   LogicalResult
462   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
463                   ConversionPatternRewriter &rewriter) const final {
464     // If the type is F32, change the type to F64.
465     if (!Type(*op->result_type_begin()).isF32())
466       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
467     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
468     return success();
469   }
470 };
471 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
472   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
473       : ConversionPattern("test.type_producer", 10, ctx) {}
474   LogicalResult
475   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
476                   ConversionPatternRewriter &rewriter) const final {
477     // Always convert to B16, even though it is not a legal type. This tests
478     // that values are unmapped correctly.
479     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
480     return success();
481   }
482 };
483 struct TestUpdateConsumerType : public ConversionPattern {
484   TestUpdateConsumerType(MLIRContext *ctx)
485       : ConversionPattern("test.type_consumer", 1, ctx) {}
486   LogicalResult
487   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
488                   ConversionPatternRewriter &rewriter) const final {
489     // Verify that the incoming operand has been successfully remapped to F64.
490     if (!operands[0].getType().isF64())
491       return failure();
492     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
493     return success();
494   }
495 };
496 
497 //===----------------------------------------------------------------------===//
498 // Non-Root Replacement Rewrite Testing
499 /// This pattern generates an invalid operation, but replaces it before the
500 /// pattern is finished. This checks that we don't need to legalize the
501 /// temporary op.
502 struct TestNonRootReplacement : public RewritePattern {
503   TestNonRootReplacement(MLIRContext *ctx)
504       : RewritePattern("test.replace_non_root", 1, ctx) {}
505 
506   LogicalResult matchAndRewrite(Operation *op,
507                                 PatternRewriter &rewriter) const final {
508     auto resultType = *op->result_type_begin();
509     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
510     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
511 
512     rewriter.replaceOp(illegalOp, {legalOp});
513     rewriter.replaceOp(op, {illegalOp});
514     return success();
515   }
516 };
517 
518 //===----------------------------------------------------------------------===//
519 // Recursive Rewrite Testing
520 /// This pattern is applied to the same operation multiple times, but has a
521 /// bounded recursion.
522 struct TestBoundedRecursiveRewrite
523     : public OpRewritePattern<TestRecursiveRewriteOp> {
524   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
525 
526   void initialize() {
527     // The conversion target handles bounding the recursion of this pattern.
528     setHasBoundedRewriteRecursion();
529   }
530 
531   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
532                                 PatternRewriter &rewriter) const final {
533     // Decrement the depth of the op in-place.
534     rewriter.updateRootInPlace(op, [&] {
535       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
536     });
537     return success();
538   }
539 };
540 
541 struct TestNestedOpCreationUndoRewrite
542     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
543   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
544 
545   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
546                                 PatternRewriter &rewriter) const final {
547     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
548     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
549     return success();
550   };
551 };
552 
553 // This pattern matches `test.blackhole` and delete this op and its producer.
554 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
555   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
556 
557   LogicalResult matchAndRewrite(BlackHoleOp op,
558                                 PatternRewriter &rewriter) const final {
559     Operation *producer = op.getOperand().getDefiningOp();
560     // Always erase the user before the producer, the framework should handle
561     // this correctly.
562     rewriter.eraseOp(op);
563     rewriter.eraseOp(producer);
564     return success();
565   };
566 };
567 
568 // This pattern replaces explicitly illegal op with explicitly legal op,
569 // but in addition creates unregistered operation.
570 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
571   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
572 
573   LogicalResult matchAndRewrite(ILLegalOpG op,
574                                 PatternRewriter &rewriter) const final {
575     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
576     Value val = rewriter.create<ConstantOp>(op->getLoc(), attr);
577     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
578     return success();
579   };
580 };
581 } // namespace
582 
583 namespace {
584 struct TestTypeConverter : public TypeConverter {
585   using TypeConverter::TypeConverter;
586   TestTypeConverter() {
587     addConversion(convertType);
588     addArgumentMaterialization(materializeCast);
589     addSourceMaterialization(materializeCast);
590   }
591 
592   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
593     // Drop I16 types.
594     if (t.isSignlessInteger(16))
595       return success();
596 
597     // Convert I64 to F64.
598     if (t.isSignlessInteger(64)) {
599       results.push_back(FloatType::getF64(t.getContext()));
600       return success();
601     }
602 
603     // Convert I42 to I43.
604     if (t.isInteger(42)) {
605       results.push_back(IntegerType::get(t.getContext(), 43));
606       return success();
607     }
608 
609     // Split F32 into F16,F16.
610     if (t.isF32()) {
611       results.assign(2, FloatType::getF16(t.getContext()));
612       return success();
613     }
614 
615     // Otherwise, convert the type directly.
616     results.push_back(t);
617     return success();
618   }
619 
620   /// Hook for materializing a conversion. This is necessary because we generate
621   /// 1->N type mappings.
622   static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
623                                          ValueRange inputs, Location loc) {
624     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
625   }
626 };
627 
628 struct TestLegalizePatternDriver
629     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
630   StringRef getArgument() const final { return "test-legalize-patterns"; }
631   StringRef getDescription() const final {
632     return "Run test dialect legalization patterns";
633   }
634   /// The mode of conversion to use with the driver.
635   enum class ConversionMode { Analysis, Full, Partial };
636 
637   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
638 
639   void getDependentDialects(DialectRegistry &registry) const override {
640     registry.insert<StandardOpsDialect>();
641   }
642 
643   void runOnOperation() override {
644     TestTypeConverter converter;
645     mlir::RewritePatternSet patterns(&getContext());
646     populateWithGenerated(patterns);
647     patterns
648         .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
649              TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
650              TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
651              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
652              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
653              TestNonRootReplacement, TestBoundedRecursiveRewrite,
654              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
655              TestCreateUnregisteredOp>(&getContext());
656     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
657     mlir::populateFuncOpTypeConversionPattern(patterns, converter);
658     mlir::populateCallOpTypeConversionPattern(patterns, converter);
659 
660     // Define the conversion target used for the test.
661     ConversionTarget target(getContext());
662     target.addLegalOp<ModuleOp>();
663     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
664                       TerminatorOp>();
665     target
666         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
667     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
668       // Don't allow F32 operands.
669       return llvm::none_of(op.getOperandTypes(),
670                            [](Type type) { return type.isF32(); });
671     });
672     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
673       return converter.isSignatureLegal(op.getType()) &&
674              converter.isLegal(&op.getBody());
675     });
676     target.addDynamicallyLegalOp<CallOp>(
677         [&](CallOp op) { return converter.isLegal(op); });
678 
679     // TestCreateUnregisteredOp creates `arith.constant` operation,
680     // which was not added to target intentionally to test
681     // correct error code from conversion driver.
682     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
683 
684     // Expect the type_producer/type_consumer operations to only operate on f64.
685     target.addDynamicallyLegalOp<TestTypeProducerOp>(
686         [](TestTypeProducerOp op) { return op.getType().isF64(); });
687     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
688       return op.getOperand().getType().isF64();
689     });
690 
691     // Check support for marking certain operations as recursively legal.
692     target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
693       return static_cast<bool>(
694           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
695     });
696 
697     // Mark the bound recursion operation as dynamically legal.
698     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
699         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
700 
701     // Handle a partial conversion.
702     if (mode == ConversionMode::Partial) {
703       DenseSet<Operation *> unlegalizedOps;
704       if (failed(applyPartialConversion(
705               getOperation(), target, std::move(patterns), &unlegalizedOps))) {
706         getOperation()->emitRemark() << "applyPartialConversion failed";
707       }
708       // Emit remarks for each legalizable operation.
709       for (auto *op : unlegalizedOps)
710         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
711       return;
712     }
713 
714     // Handle a full conversion.
715     if (mode == ConversionMode::Full) {
716       // Check support for marking unknown operations as dynamically legal.
717       target.markUnknownOpDynamicallyLegal([](Operation *op) {
718         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
719       });
720 
721       if (failed(applyFullConversion(getOperation(), target,
722                                      std::move(patterns)))) {
723         getOperation()->emitRemark() << "applyFullConversion failed";
724       }
725       return;
726     }
727 
728     // Otherwise, handle an analysis conversion.
729     assert(mode == ConversionMode::Analysis);
730 
731     // Analyze the convertible operations.
732     DenseSet<Operation *> legalizedOps;
733     if (failed(applyAnalysisConversion(getOperation(), target,
734                                        std::move(patterns), legalizedOps)))
735       return signalPassFailure();
736 
737     // Emit remarks for each legalizable operation.
738     for (auto *op : legalizedOps)
739       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
740   }
741 
742   /// The mode of conversion to use.
743   ConversionMode mode;
744 };
745 } // end anonymous namespace
746 
/// Command-line option (`-test-legalize-mode`) selecting the conversion mode
/// handed to TestLegalizePatternDriver when the pass is constructed during
/// registration. Defaults to a partial conversion.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
759 
760 //===----------------------------------------------------------------------===//
761 // ConversionPatternRewriter::getRemappedValue testing. This method is used
762 // to get the remapped value of an original value that was replaced using
763 // ConversionPatternRewriter.
764 namespace {
/// Type converter handed to TestRemapValueInRegion by the test-remapped-value
/// pass. It registers an f32 -> f64 rule followed by a catch-all identity
/// rule.
///
/// NOTE(review): TypeConverter tries conversions most-recently-registered
/// first, so the identity rule registered last appears to take precedence over
/// the f32 -> f64 rule for every type. Confirm this is intended before
/// reordering; the registration order below is load-bearing.
struct TestRemapValueTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  TestRemapValueTypeConverter() {
    // Convert f32 to f64.
    addConversion(
        [](Float32Type type) { return Float64Type::get(type.getContext()); });
    // Identity conversion for any type.
    addConversion([](Type type) { return type; });
  }
};
774 
775 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
776 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
777 /// operand twice.
778 ///
779 /// Example:
780 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
781 /// is replaced with:
782 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
783 struct OneVResOneVOperandOp1Converter
784     : public OpConversionPattern<OneVResOneVOperandOp1> {
785   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
786 
787   LogicalResult
788   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
789                   ConversionPatternRewriter &rewriter) const override {
790     auto origOps = op.getOperands();
791     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
792            "One operand expected");
793     Value origOp = *origOps.begin();
794     SmallVector<Value, 2> remappedOperands;
795     // Replicate the remapped original operand twice. Note that we don't used
796     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
797     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
798     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
799 
800     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
801                                                        remappedOperands);
802     return success();
803   }
804 };
805 
/// A conversion pattern that inlines the entry block of a
/// TestRemappedValueRegionOp's body into the block containing the op, then
/// replaces the op's results with the remapped values of the body
/// terminator's operands (obtained via
/// ConversionPatternRewriter::getRemappedValues) and erases the terminator.
struct TestRemapValueInRegion
    : public OpConversionPattern<TestRemappedValueRegionOp> {
  using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Block &block = op.getBody().front();
    Operation *terminator = block.getTerminator();

    // Inline the body into the enclosing block: split the enclosing block at
    // `op` (so `op` and the operations after it land in `finalBlock`), append
    // the body block, then append `finalBlock` back after it.
    Block *parentBlock = op->getBlock();
    Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
    rewriter.mergeBlocks(&block, parentBlock, ValueRange());
    rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());

    // Replace the results of this operation with the remapped terminator
    // values.
    // NOTE(review): the IR has already been restructured at this point, so a
    // failure here relies on the conversion driver rolling those edits back.
    SmallVector<Value> terminatorOperands;
    if (failed(rewriter.getRemappedValues(terminator->getOperands(),
                                          terminatorOperands)))
      return failure();

    rewriter.eraseOp(terminator);
    rewriter.replaceOp(op, terminatorOperands);
    return success();
  }
};
835 
836 struct TestRemappedValue
837     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
838   StringRef getArgument() const final { return "test-remapped-value"; }
839   StringRef getDescription() const final {
840     return "Test public remapped value mechanism in ConversionPatternRewriter";
841   }
842   void runOnFunction() override {
843     TestRemapValueTypeConverter typeConverter;
844 
845     mlir::RewritePatternSet patterns(&getContext());
846     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
847     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
848         &getContext());
849     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
850 
851     mlir::ConversionTarget target(getContext());
852     target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
853 
854     // Expect the type_producer/type_consumer operations to only operate on f64.
855     target.addDynamicallyLegalOp<TestTypeProducerOp>(
856         [](TestTypeProducerOp op) { return op.getType().isF64(); });
857     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
858       return op.getOperand().getType().isF64();
859     });
860 
861     // We make OneVResOneVOperandOp1 legal only when it has more that one
862     // operand. This will trigger the conversion that will replace one-operand
863     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
864     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
865         [](Operation *op) { return op->getNumOperands() > 1; });
866 
867     if (failed(mlir::applyFullConversion(getFunction(), target,
868                                          std::move(patterns)))) {
869       signalPassFailure();
870     }
871   }
872 };
873 } // end anonymous namespace
874 
875 //===----------------------------------------------------------------------===//
876 // Test patterns without a specific root operation kind
877 //===----------------------------------------------------------------------===//
878 
879 namespace {
880 /// This pattern matches and removes any operation in the test dialect.
881 struct RemoveTestDialectOps : public RewritePattern {
882   RemoveTestDialectOps(MLIRContext *context)
883       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
884 
885   LogicalResult matchAndRewrite(Operation *op,
886                                 PatternRewriter &rewriter) const override {
887     if (!isa<TestDialect>(op->getDialect()))
888       return failure();
889     rewriter.eraseOp(op);
890     return success();
891   }
892 };
893 
894 struct TestUnknownRootOpDriver
895     : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
896   StringRef getArgument() const final {
897     return "test-legalize-unknown-root-patterns";
898   }
899   StringRef getDescription() const final {
900     return "Test public remapped value mechanism in ConversionPatternRewriter";
901   }
902   void runOnFunction() override {
903     mlir::RewritePatternSet patterns(&getContext());
904     patterns.add<RemoveTestDialectOps>(&getContext());
905 
906     mlir::ConversionTarget target(getContext());
907     target.addIllegalDialect<TestDialect>();
908     if (failed(
909             applyPartialConversion(getFunction(), target, std::move(patterns))))
910       signalPassFailure();
911   }
912 };
913 } // end anonymous namespace
914 
915 //===----------------------------------------------------------------------===//
916 // Test type conversions
917 //===----------------------------------------------------------------------===//
918 
919 namespace {
920 struct TestTypeConversionProducer
921     : public OpConversionPattern<TestTypeProducerOp> {
922   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
923   LogicalResult
924   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
925                   ConversionPatternRewriter &rewriter) const final {
926     Type resultType = op.getType();
927     if (resultType.isa<FloatType>())
928       resultType = rewriter.getF64Type();
929     else if (resultType.isInteger(16))
930       resultType = rewriter.getIntegerType(64);
931     else
932       return failure();
933 
934     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
935     return success();
936   }
937 };
938 
939 /// Call signature conversion and then fail the rewrite to trigger the undo
940 /// mechanism.
941 struct TestSignatureConversionUndo
942     : public OpConversionPattern<TestSignatureConversionUndoOp> {
943   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
944 
945   LogicalResult
946   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
947                   ConversionPatternRewriter &rewriter) const final {
948     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
949     return failure();
950   }
951 };
952 
953 /// Call signature conversion without providing a type converter to handle
954 /// materializations.
955 struct TestTestSignatureConversionNoConverter
956     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
957   TestTestSignatureConversionNoConverter(TypeConverter &converter,
958                                          MLIRContext *context)
959       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
960         converter(converter) {}
961 
962   LogicalResult
963   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
964                   ConversionPatternRewriter &rewriter) const final {
965     Region &region = op->getRegion(0);
966     Block *entry = &region.front();
967 
968     // Convert the original entry arguments.
969     TypeConverter::SignatureConversion result(entry->getNumArguments());
970     if (failed(
971             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
972       return failure();
973     rewriter.updateRootInPlace(
974         op, [&] { rewriter.applySignatureConversion(&region, result); });
975     return success();
976   }
977 
978   TypeConverter &converter;
979 };
980 
981 /// Just forward the operands to the root op. This is essentially a no-op
982 /// pattern that is used to trigger target materialization.
983 struct TestTypeConsumerForward
984     : public OpConversionPattern<TestTypeConsumerOp> {
985   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
986 
987   LogicalResult
988   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
989                   ConversionPatternRewriter &rewriter) const final {
990     rewriter.updateRootInPlace(op,
991                                [&] { op->setOperands(adaptor.getOperands()); });
992     return success();
993   }
994 };
995 
996 struct TestTypeConversionAnotherProducer
997     : public OpRewritePattern<TestAnotherTypeProducerOp> {
998   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
999 
1000   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1001                                 PatternRewriter &rewriter) const final {
1002     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1003     return success();
1004   }
1005 };
1006 
/// Module pass exercising type conversion in DialectConversion: 1->1 and 1->0
/// type conversions, source materializations, region signature conversion
/// (with rollback, and without a converter), and FuncOp signature conversion,
/// all applied as a partial conversion.
struct TestTypeConversionDriver
    : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TestDialect>();
  }
  StringRef getArgument() const final {
    return "test-legalize-type-conversion";
  }
  StringRef getDescription() const final {
    return "Test various type conversion functionalities in DialectConversion";
  }

  void runOnOperation() override {
    // Initialize the type converter.
    TypeConverter converter;

    // Add the legal set of type conversions. TypeConverter tries the most
    // recently added conversion first, so the IntegerType rule registered
    // second takes precedence over this catch-all for integer types.
    converter.addConversion([](Type type) -> Type {
      // Treat F64 as legal.
      if (type.isF64())
        return type;
      // Allow converting BF16/F16/F32 to F64.
      if (type.isBF16() || type.isF16() || type.isF32())
        return FloatType::getF64(type.getContext());
      // Otherwise, the type is illegal.
      return nullptr;
    });
    converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
      // Drop all integer types: succeed without producing any result type
      // (a 1->0 conversion).
      return success();
    });

    // Add the legal set of source materializations.
    converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
                                          ValueRange inputs,
                                          Location loc) -> Value {
      // Allow casting a single F64 input to any result type except F16.
      if (!resultType.isF16() && inputs.size() == 1 &&
          inputs[0].getType().isF64())
        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
      // Allow producing an i32 or i64 from nothing.
      if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
          inputs.empty())
        return builder.create<TestTypeProducerOp>(loc, resultType);
      // Allow casting a single integer input to any integer result type.
      if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
          inputs[0].getType().isa<IntegerType>())
        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
      // Otherwise, fail.
      return nullptr;
    });

    // Initialize the conversion target.
    mlir::ConversionTarget target(getContext());
    // Producers are legal once they yield f64 or i64.
    target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
      return op.getType().isF64() || op.getType().isInteger(64);
    });
    // Functions are legal once the converter accepts both their signature and
    // their body region.
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      return converter.isSignatureLegal(op.getType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
      // Allow casts from F64 to F32 (only the first operand type is checked).
      return (*op.operand_type_begin()).isF64() && op.getType().isF32();
    });
    // Legal once the entry block argument types of the region are legal.
    target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
        [&](TestSignatureConversionNoConverterOp op) {
          return converter.isLegal(op.getRegion().front().getArgumentTypes());
        });

    // Initialize the set of rewrite patterns.
    RewritePatternSet patterns(&getContext());
    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
                 TestSignatureConversionUndo,
                 TestTestSignatureConversionNoConverter>(converter,
                                                         &getContext());
    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
    mlir::populateFuncOpTypeConversionPattern(patterns, converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
1091 } // end anonymous namespace
1092 
1093 //===----------------------------------------------------------------------===//
1094 // Test Block Merging
1095 //===----------------------------------------------------------------------===//
1096 
1097 namespace {
1098 /// A rewriter pattern that tests that blocks can be merged.
1099 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1100   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1101 
1102   LogicalResult
1103   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1104                   ConversionPatternRewriter &rewriter) const final {
1105     Block &firstBlock = op.getBody().front();
1106     Operation *branchOp = firstBlock.getTerminator();
1107     Block *secondBlock = &*(std::next(op.getBody().begin()));
1108     auto succOperands = branchOp->getOperands();
1109     SmallVector<Value, 2> replacements(succOperands);
1110     rewriter.eraseOp(branchOp);
1111     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1112     rewriter.updateRootInPlace(op, [] {});
1113     return success();
1114   }
1115 };
1116 
1117 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1118 struct TestUndoBlocksMerge : public ConversionPattern {
1119   TestUndoBlocksMerge(MLIRContext *ctx)
1120       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1121   LogicalResult
1122   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1123                   ConversionPatternRewriter &rewriter) const final {
1124     Block &firstBlock = op->getRegion(0).front();
1125     Operation *branchOp = firstBlock.getTerminator();
1126     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1127     rewriter.setInsertionPointToStart(secondBlock);
1128     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1129     auto succOperands = branchOp->getOperands();
1130     SmallVector<Value, 2> replacements(succOperands);
1131     rewriter.eraseOp(branchOp);
1132     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1133     rewriter.updateRootInPlace(op, [] {});
1134     return success();
1135   }
1136 };
1137 
1138 /// A rewrite mechanism to inline the body of the op into its parent, when both
1139 /// ops can have a single block.
1140 struct TestMergeSingleBlockOps
1141     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1142   using OpConversionPattern<
1143       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1144 
1145   LogicalResult
1146   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1147                   ConversionPatternRewriter &rewriter) const final {
1148     SingleBlockImplicitTerminatorOp parentOp =
1149         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1150     if (!parentOp)
1151       return failure();
1152     Block &innerBlock = op.getRegion().front();
1153     TerminatorOp innerTerminator =
1154         cast<TerminatorOp>(innerBlock.getTerminator());
1155     rewriter.mergeBlockBefore(&innerBlock, op);
1156     rewriter.eraseOp(innerTerminator);
1157     rewriter.eraseOp(op);
1158     rewriter.updateRootInPlace(op, [] {});
1159     return success();
1160   }
1161 };
1162 
1163 struct TestMergeBlocksPatternDriver
1164     : public PassWrapper<TestMergeBlocksPatternDriver,
1165                          OperationPass<ModuleOp>> {
1166   StringRef getArgument() const final { return "test-merge-blocks"; }
1167   StringRef getDescription() const final {
1168     return "Test Merging operation in ConversionPatternRewriter";
1169   }
1170   void runOnOperation() override {
1171     MLIRContext *context = &getContext();
1172     mlir::RewritePatternSet patterns(context);
1173     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1174         context);
1175     ConversionTarget target(*context);
1176     target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1177                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1178     target.addIllegalOp<ILLegalOpF>();
1179 
1180     /// Expect the op to have a single block after legalization.
1181     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1182         [&](TestMergeBlocksOp op) -> bool {
1183           return llvm::hasSingleElement(op.getBody());
1184         });
1185 
1186     /// Only allow `test.br` within test.merge_blocks op.
1187     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1188       return op->getParentOfType<TestMergeBlocksOp>();
1189     });
1190 
1191     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1192     /// inlined.
1193     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1194         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1195           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1196         });
1197 
1198     DenseSet<Operation *> unlegalizedOps;
1199     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1200                                  &unlegalizedOps);
1201     for (auto *op : unlegalizedOps)
1202       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1203   }
1204 };
1205 } // namespace
1206 
1207 //===----------------------------------------------------------------------===//
1208 // Test Selective Replacement
1209 //===----------------------------------------------------------------------===//
1210 
1211 namespace {
1212 /// A rewrite mechanism to inline the body of the op into its parent, when both
1213 /// ops can have a single block.
1214 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1215   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1216 
1217   LogicalResult matchAndRewrite(TestCastOp op,
1218                                 PatternRewriter &rewriter) const final {
1219     if (op.getNumOperands() != 2)
1220       return failure();
1221     OperandRange operands = op.getOperands();
1222 
1223     // Replace non-terminator uses with the first operand.
1224     rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1225       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1226     });
1227     // Replace everything else with the second operand if the operation isn't
1228     // dead.
1229     rewriter.replaceOp(op, op.getOperand(1));
1230     return success();
1231   }
1232 };
1233 
1234 struct TestSelectiveReplacementPatternDriver
1235     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1236                          OperationPass<>> {
1237   StringRef getArgument() const final {
1238     return "test-pattern-selective-replacement";
1239   }
1240   StringRef getDescription() const final {
1241     return "Test selective replacement in the PatternRewriter";
1242   }
1243   void runOnOperation() override {
1244     MLIRContext *context = &getContext();
1245     mlir::RewritePatternSet patterns(context);
1246     patterns.add<TestSelectiveOpReplacementPattern>(context);
1247     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1248                                        std::move(patterns));
1249   }
1250 };
1251 } // namespace
1252 
1253 //===----------------------------------------------------------------------===//
1254 // PassRegistration
1255 //===----------------------------------------------------------------------===//
1256 
1257 namespace mlir {
1258 namespace test {
1259 void registerPatternsTestPass() {
1260   PassRegistration<TestReturnTypeDriver>();
1261 
1262   PassRegistration<TestDerivedAttributeDriver>();
1263 
1264   PassRegistration<TestPatternDriver>();
1265 
1266   PassRegistration<TestLegalizePatternDriver>([] {
1267     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
1268   });
1269 
1270   PassRegistration<TestRemappedValue>();
1271 
1272   PassRegistration<TestUnknownRootOpDriver>();
1273 
1274   PassRegistration<TestTypeConversionDriver>();
1275 
1276   PassRegistration<TestMergeBlocksPatternDriver>();
1277   PassRegistration<TestSelectiveReplacementPatternDriver>();
1278 }
1279 } // namespace test
1280 } // namespace mlir
1281