1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/FoldUtils.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace test;
23 
24 // Native function for testing NativeCodeCall
25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
26   return choice.getValue() ? input1 : input2;
27 }
28 
29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
30   rewriter.create<OpI>(loc, input);
31 }
32 
33 static void handleNoResultOp(PatternRewriter &rewriter,
34                              OpSymbolBindingNoResult op) {
35   // Turn the no result op to a one-result op.
36   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
37                                     op.getOperand());
38 }
39 
40 static bool getFirstI32Result(Operation *op, Value &value) {
41   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
42     return false;
43   value = op->getResult(0);
44   return true;
45 }
46 
47 static Value bindNativeCodeCallResult(Value value) { return value; }
48 
49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
50                                                               Value input2) {
51   return SmallVector<Value, 2>({input2, input1});
52 }
53 
54 // Test that natives calls are only called once during rewrites.
55 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
56 // This let us check the number of times OpM_Test was called by inspecting
57 // the returned value in the MLIR output.
58 static int64_t opMIncreasingValue = 314159265;
59 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
60   int64_t i = opMIncreasingValue++;
61   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
62 }
63 
64 namespace {
65 #include "TestPatterns.inc"
66 } // namespace
67 
68 //===----------------------------------------------------------------------===//
69 // Test Reduce Pattern Interface
70 //===----------------------------------------------------------------------===//
71 
72 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
73   populateWithGenerated(patterns);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Canonicalizer Driver.
78 //===----------------------------------------------------------------------===//
79 
80 namespace {
81 struct FoldingPattern : public RewritePattern {
82 public:
83   FoldingPattern(MLIRContext *context)
84       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
85                        /*benefit=*/1, context) {}
86 
87   LogicalResult matchAndRewrite(Operation *op,
88                                 PatternRewriter &rewriter) const override {
89     // Exercise OperationFolder API for a single-result operation that is folded
90     // upon construction. The operation being created through the folder has an
91     // in-place folder, and it should be still present in the output.
92     // Furthermore, the folder should not crash when attempting to recover the
93     // (unchanged) operation result.
94     OperationFolder folder(op->getContext());
95     Value result = folder.create<TestOpInPlaceFold>(
96         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
97         rewriter.getI32IntegerAttr(0));
98     assert(result);
99     rewriter.replaceOp(op, result);
100     return success();
101   }
102 };
103 
104 /// This pattern creates a foldable operation at the entry point of the block.
105 /// This tests the situation where the operation folder will need to replace an
106 /// operation with a previously created constant that does not initially
107 /// dominate the operation to replace.
108 struct FolderInsertBeforePreviouslyFoldedConstantPattern
109     : public OpRewritePattern<TestCastOp> {
110 public:
111   using OpRewritePattern<TestCastOp>::OpRewritePattern;
112 
113   LogicalResult matchAndRewrite(TestCastOp op,
114                                 PatternRewriter &rewriter) const override {
115     if (!op->hasAttr("test_fold_before_previously_folded_op"))
116       return failure();
117     rewriter.setInsertionPointToStart(op->getBlock());
118 
119     auto constOp = rewriter.create<arith::ConstantOp>(
120         op.getLoc(), rewriter.getBoolAttr(true));
121     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
122                                             Value(constOp));
123     return success();
124   }
125 };
126 
127 /// This pattern matches test.op_commutative2 with the first operand being
128 /// another test.op_commutative2 with a constant on the right side and fold it
129 /// away by propagating it as its result. This is intend to check that patterns
130 /// are applied after the commutative property moves constant to the right.
131 struct FolderCommutativeOp2WithConstant
132     : public OpRewritePattern<TestCommutative2Op> {
133 public:
134   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
135 
136   LogicalResult matchAndRewrite(TestCommutative2Op op,
137                                 PatternRewriter &rewriter) const override {
138     auto operand =
139         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
140     if (!operand)
141       return failure();
142     Attribute constInput;
143     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
144       return failure();
145     rewriter.replaceOp(op, operand->getOperand(1));
146     return success();
147   }
148 };
149 
150 struct TestPatternDriver
151     : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
152   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
153 
154   StringRef getArgument() const final { return "test-patterns"; }
155   StringRef getDescription() const final { return "Run test dialect patterns"; }
156   void runOnOperation() override {
157     mlir::RewritePatternSet patterns(&getContext());
158     populateWithGenerated(patterns);
159 
160     // Verify named pattern is generated with expected name.
161     patterns.add<FoldingPattern, TestNamedPatternRule,
162                  FolderInsertBeforePreviouslyFoldedConstantPattern,
163                  FolderCommutativeOp2WithConstant>(&getContext());
164 
165     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
166   }
167 };
168 } // namespace
169 
170 //===----------------------------------------------------------------------===//
171 // ReturnType Driver.
172 //===----------------------------------------------------------------------===//
173 
174 namespace {
175 // Generate ops for each instance where the type can be successfully inferred.
176 template <typename OpTy>
177 static void invokeCreateWithInferredReturnType(Operation *op) {
178   auto *context = op->getContext();
179   auto fop = op->getParentOfType<func::FuncOp>();
180   auto location = UnknownLoc::get(context);
181   OpBuilder b(op);
182   b.setInsertionPointAfter(op);
183 
184   // Use permutations of 2 args as operands.
185   assert(fop.getNumArguments() >= 2);
186   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
187     for (int j = 0; j < e; ++j) {
188       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
189       SmallVector<Type, 2> inferredReturnTypes;
190       if (succeeded(OpTy::inferReturnTypes(
191               context, llvm::None, values, op->getAttrDictionary(),
192               op->getRegions(), inferredReturnTypes))) {
193         OperationState state(location, OpTy::getOperationName());
194         // TODO: Expand to regions.
195         OpTy::build(b, state, values, op->getAttrs());
196         (void)b.create(state);
197       }
198     }
199   }
200 }
201 
202 static void reifyReturnShape(Operation *op) {
203   OpBuilder b(op);
204 
205   // Use permutations of 2 args as operands.
206   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
207   SmallVector<Value, 2> shapes;
208   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
209       !llvm::hasSingleElement(shapes))
210     return;
211   for (const auto &it : llvm::enumerate(shapes)) {
212     op->emitRemark() << "value " << it.index() << ": "
213                      << it.value().getDefiningOp();
214   }
215 }
216 
217 struct TestReturnTypeDriver
218     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
219   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
220 
221   void getDependentDialects(DialectRegistry &registry) const override {
222     registry.insert<tensor::TensorDialect>();
223   }
224   StringRef getArgument() const final { return "test-return-type"; }
225   StringRef getDescription() const final { return "Run return type functions"; }
226 
227   void runOnOperation() override {
228     if (getOperation().getName() == "testCreateFunctions") {
229       std::vector<Operation *> ops;
230       // Collect ops to avoid triggering on inserted ops.
231       for (auto &op : getOperation().getBody().front())
232         ops.push_back(&op);
233       // Generate test patterns for each, but skip terminator.
234       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
235         // Test create method of each of the Op classes below. The resultant
236         // output would be in reverse order underneath `op` from which
237         // the attributes and regions are used.
238         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
239         invokeCreateWithInferredReturnType<
240             OpWithShapedTypeInferTypeInterfaceOp>(op);
241       };
242       return;
243     }
244     if (getOperation().getName() == "testReifyFunctions") {
245       std::vector<Operation *> ops;
246       // Collect ops to avoid triggering on inserted ops.
247       for (auto &op : getOperation().getBody().front())
248         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
249           ops.push_back(&op);
250       // Generate test patterns for each, but skip terminator.
251       for (auto *op : ops)
252         reifyReturnShape(op);
253     }
254   }
255 };
256 } // namespace
257 
258 namespace {
259 struct TestDerivedAttributeDriver
260     : public PassWrapper<TestDerivedAttributeDriver,
261                          OperationPass<func::FuncOp>> {
262   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
263 
264   StringRef getArgument() const final { return "test-derived-attr"; }
265   StringRef getDescription() const final {
266     return "Run test derived attributes";
267   }
268   void runOnOperation() override;
269 };
270 } // namespace
271 
272 void TestDerivedAttributeDriver::runOnOperation() {
273   getOperation().walk([](DerivedAttributeOpInterface dOp) {
274     auto dAttr = dOp.materializeDerivedAttributes();
275     if (!dAttr)
276       return;
277     for (auto d : dAttr)
278       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
279   });
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Legalization Driver.
284 //===----------------------------------------------------------------------===//
285 
286 namespace {
287 //===----------------------------------------------------------------------===//
288 // Region-Block Rewrite Testing
289 
290 /// This pattern is a simple pattern that inlines the first region of a given
291 /// operation into the parent region.
292 struct TestRegionRewriteBlockMovement : public ConversionPattern {
293   TestRegionRewriteBlockMovement(MLIRContext *ctx)
294       : ConversionPattern("test.region", 1, ctx) {}
295 
296   LogicalResult
297   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
298                   ConversionPatternRewriter &rewriter) const final {
299     // Inline this region into the parent region.
300     auto &parentRegion = *op->getParentRegion();
301     auto &opRegion = op->getRegion(0);
302     if (op->getAttr("legalizer.should_clone"))
303       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
304     else
305       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
306 
307     if (op->getAttr("legalizer.erase_old_blocks")) {
308       while (!opRegion.empty())
309         rewriter.eraseBlock(&opRegion.front());
310     }
311 
312     // Drop this operation.
313     rewriter.eraseOp(op);
314     return success();
315   }
316 };
317 /// This pattern is a simple pattern that generates a region containing an
318 /// illegal operation.
319 struct TestRegionRewriteUndo : public RewritePattern {
320   TestRegionRewriteUndo(MLIRContext *ctx)
321       : RewritePattern("test.region_builder", 1, ctx) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const final {
325     // Create the region operation with an entry block containing arguments.
326     OperationState newRegion(op->getLoc(), "test.region");
327     newRegion.addRegion();
328     auto *regionOp = rewriter.create(newRegion);
329     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
330     entryBlock->addArgument(rewriter.getIntegerType(64),
331                             rewriter.getUnknownLoc());
332 
333     // Add an explicitly illegal operation to ensure the conversion fails.
334     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
335     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
336 
337     // Drop this operation.
338     rewriter.eraseOp(op);
339     return success();
340   }
341 };
342 /// A simple pattern that creates a block at the end of the parent region of the
343 /// matched operation.
344 struct TestCreateBlock : public RewritePattern {
345   TestCreateBlock(MLIRContext *ctx)
346       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
347 
348   LogicalResult matchAndRewrite(Operation *op,
349                                 PatternRewriter &rewriter) const final {
350     Region &region = *op->getParentRegion();
351     Type i32Type = rewriter.getIntegerType(32);
352     Location loc = op->getLoc();
353     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
354     rewriter.create<TerminatorOp>(loc);
355     rewriter.replaceOp(op, {});
356     return success();
357   }
358 };
359 
360 /// A simple pattern that creates a block containing an invalid operation in
361 /// order to trigger the block creation undo mechanism.
362 struct TestCreateIllegalBlock : public RewritePattern {
363   TestCreateIllegalBlock(MLIRContext *ctx)
364       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
365 
366   LogicalResult matchAndRewrite(Operation *op,
367                                 PatternRewriter &rewriter) const final {
368     Region &region = *op->getParentRegion();
369     Type i32Type = rewriter.getIntegerType(32);
370     Location loc = op->getLoc();
371     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
372     // Create an illegal op to ensure the conversion fails.
373     rewriter.create<ILLegalOpF>(loc, i32Type);
374     rewriter.create<TerminatorOp>(loc);
375     rewriter.replaceOp(op, {});
376     return success();
377   }
378 };
379 
380 /// A simple pattern that tests the undo mechanism when replacing the uses of a
381 /// block argument.
382 struct TestUndoBlockArgReplace : public ConversionPattern {
383   TestUndoBlockArgReplace(MLIRContext *ctx)
384       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
385 
386   LogicalResult
387   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
388                   ConversionPatternRewriter &rewriter) const final {
389     auto illegalOp =
390         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
391     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
392                                         illegalOp);
393     rewriter.updateRootInPlace(op, [] {});
394     return success();
395   }
396 };
397 
398 /// A rewrite pattern that tests the undo mechanism when erasing a block.
399 struct TestUndoBlockErase : public ConversionPattern {
400   TestUndoBlockErase(MLIRContext *ctx)
401       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
402 
403   LogicalResult
404   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
405                   ConversionPatternRewriter &rewriter) const final {
406     Block *secondBlock = &*std::next(op->getRegion(0).begin());
407     rewriter.setInsertionPointToStart(secondBlock);
408     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
409     rewriter.eraseBlock(secondBlock);
410     rewriter.updateRootInPlace(op, [] {});
411     return success();
412   }
413 };
414 
415 //===----------------------------------------------------------------------===//
416 // Type-Conversion Rewrite Testing
417 
418 /// This patterns erases a region operation that has had a type conversion.
419 struct TestDropOpSignatureConversion : public ConversionPattern {
420   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
421       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
422   LogicalResult
423   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
424                   ConversionPatternRewriter &rewriter) const override {
425     Region &region = op->getRegion(0);
426     Block *entry = &region.front();
427 
428     // Convert the original entry arguments.
429     TypeConverter &converter = *getTypeConverter();
430     TypeConverter::SignatureConversion result(entry->getNumArguments());
431     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
432                                               result)) ||
433         failed(rewriter.convertRegionTypes(&region, converter, &result)))
434       return failure();
435 
436     // Convert the region signature and just drop the operation.
437     rewriter.eraseOp(op);
438     return success();
439   }
440 };
441 /// This pattern simply updates the operands of the given operation.
442 struct TestPassthroughInvalidOp : public ConversionPattern {
443   TestPassthroughInvalidOp(MLIRContext *ctx)
444       : ConversionPattern("test.invalid", 1, ctx) {}
445   LogicalResult
446   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
447                   ConversionPatternRewriter &rewriter) const final {
448     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
449                                              llvm::None);
450     return success();
451   }
452 };
453 /// This pattern handles the case of a split return value.
454 struct TestSplitReturnType : public ConversionPattern {
455   TestSplitReturnType(MLIRContext *ctx)
456       : ConversionPattern("test.return", 1, ctx) {}
457   LogicalResult
458   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
459                   ConversionPatternRewriter &rewriter) const final {
460     // Check for a return of F32.
461     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
462       return failure();
463 
464     // Check if the first operation is a cast operation, if it is we use the
465     // results directly.
466     auto *defOp = operands[0].getDefiningOp();
467     if (auto packerOp =
468             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
469       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
470       return success();
471     }
472 
473     // Otherwise, fail to match.
474     return failure();
475   }
476 };
477 
478 //===----------------------------------------------------------------------===//
479 // Multi-Level Type-Conversion Rewrite Testing
480 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
481   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
482       : ConversionPattern("test.type_producer", 1, ctx) {}
483   LogicalResult
484   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
485                   ConversionPatternRewriter &rewriter) const final {
486     // If the type is I32, change the type to F32.
487     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
488       return failure();
489     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
490     return success();
491   }
492 };
493 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
494   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
495       : ConversionPattern("test.type_producer", 1, ctx) {}
496   LogicalResult
497   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
498                   ConversionPatternRewriter &rewriter) const final {
499     // If the type is F32, change the type to F64.
500     if (!Type(*op->result_type_begin()).isF32())
501       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
502     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
503     return success();
504   }
505 };
506 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
507   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
508       : ConversionPattern("test.type_producer", 10, ctx) {}
509   LogicalResult
510   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
511                   ConversionPatternRewriter &rewriter) const final {
512     // Always convert to B16, even though it is not a legal type. This tests
513     // that values are unmapped correctly.
514     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
515     return success();
516   }
517 };
518 struct TestUpdateConsumerType : public ConversionPattern {
519   TestUpdateConsumerType(MLIRContext *ctx)
520       : ConversionPattern("test.type_consumer", 1, ctx) {}
521   LogicalResult
522   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
523                   ConversionPatternRewriter &rewriter) const final {
524     // Verify that the incoming operand has been successfully remapped to F64.
525     if (!operands[0].getType().isF64())
526       return failure();
527     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
528     return success();
529   }
530 };
531 
532 //===----------------------------------------------------------------------===//
533 // Non-Root Replacement Rewrite Testing
534 /// This pattern generates an invalid operation, but replaces it before the
535 /// pattern is finished. This checks that we don't need to legalize the
536 /// temporary op.
537 struct TestNonRootReplacement : public RewritePattern {
538   TestNonRootReplacement(MLIRContext *ctx)
539       : RewritePattern("test.replace_non_root", 1, ctx) {}
540 
541   LogicalResult matchAndRewrite(Operation *op,
542                                 PatternRewriter &rewriter) const final {
543     auto resultType = *op->result_type_begin();
544     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
545     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
546 
547     rewriter.replaceOp(illegalOp, {legalOp});
548     rewriter.replaceOp(op, {illegalOp});
549     return success();
550   }
551 };
552 
553 //===----------------------------------------------------------------------===//
554 // Recursive Rewrite Testing
555 /// This pattern is applied to the same operation multiple times, but has a
556 /// bounded recursion.
557 struct TestBoundedRecursiveRewrite
558     : public OpRewritePattern<TestRecursiveRewriteOp> {
559   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
560 
561   void initialize() {
562     // The conversion target handles bounding the recursion of this pattern.
563     setHasBoundedRewriteRecursion();
564   }
565 
566   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
567                                 PatternRewriter &rewriter) const final {
568     // Decrement the depth of the op in-place.
569     rewriter.updateRootInPlace(op, [&] {
570       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
571     });
572     return success();
573   }
574 };
575 
576 struct TestNestedOpCreationUndoRewrite
577     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
578   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
579 
580   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
581                                 PatternRewriter &rewriter) const final {
582     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
583     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
584     return success();
585   };
586 };
587 
588 // This pattern matches `test.blackhole` and delete this op and its producer.
589 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
590   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
591 
592   LogicalResult matchAndRewrite(BlackHoleOp op,
593                                 PatternRewriter &rewriter) const final {
594     Operation *producer = op.getOperand().getDefiningOp();
595     // Always erase the user before the producer, the framework should handle
596     // this correctly.
597     rewriter.eraseOp(op);
598     rewriter.eraseOp(producer);
599     return success();
600   };
601 };
602 
603 // This pattern replaces explicitly illegal op with explicitly legal op,
604 // but in addition creates unregistered operation.
605 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
606   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
607 
608   LogicalResult matchAndRewrite(ILLegalOpG op,
609                                 PatternRewriter &rewriter) const final {
610     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
611     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
612     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
613     return success();
614   };
615 };
616 } // namespace
617 
618 namespace {
619 struct TestTypeConverter : public TypeConverter {
620   using TypeConverter::TypeConverter;
621   TestTypeConverter() {
622     addConversion(convertType);
623     addArgumentMaterialization(materializeCast);
624     addSourceMaterialization(materializeCast);
625   }
626 
627   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
628     // Drop I16 types.
629     if (t.isSignlessInteger(16))
630       return success();
631 
632     // Convert I64 to F64.
633     if (t.isSignlessInteger(64)) {
634       results.push_back(FloatType::getF64(t.getContext()));
635       return success();
636     }
637 
638     // Convert I42 to I43.
639     if (t.isInteger(42)) {
640       results.push_back(IntegerType::get(t.getContext(), 43));
641       return success();
642     }
643 
644     // Split F32 into F16,F16.
645     if (t.isF32()) {
646       results.assign(2, FloatType::getF16(t.getContext()));
647       return success();
648     }
649 
650     // Otherwise, convert the type directly.
651     results.push_back(t);
652     return success();
653   }
654 
655   /// Hook for materializing a conversion. This is necessary because we generate
656   /// 1->N type mappings.
657   static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
658                                          ValueRange inputs, Location loc) {
659     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
660   }
661 };
662 
663 struct TestLegalizePatternDriver
664     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
665   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
666 
667   StringRef getArgument() const final { return "test-legalize-patterns"; }
668   StringRef getDescription() const final {
669     return "Run test dialect legalization patterns";
670   }
671   /// The mode of conversion to use with the driver.
672   enum class ConversionMode { Analysis, Full, Partial };
673 
674   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
675 
676   void getDependentDialects(DialectRegistry &registry) const override {
677     registry.insert<func::FuncDialect>();
678   }
679 
680   void runOnOperation() override {
681     TestTypeConverter converter;
682     mlir::RewritePatternSet patterns(&getContext());
683     populateWithGenerated(patterns);
684     patterns
685         .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
686              TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
687              TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
688              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
689              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
690              TestNonRootReplacement, TestBoundedRecursiveRewrite,
691              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
692              TestCreateUnregisteredOp>(&getContext());
693     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
694     mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
695         patterns, converter);
696     mlir::populateCallOpTypeConversionPattern(patterns, converter);
697 
698     // Define the conversion target used for the test.
699     ConversionTarget target(getContext());
700     target.addLegalOp<ModuleOp>();
701     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
702                       TerminatorOp>();
703     target
704         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
705     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
706       // Don't allow F32 operands.
707       return llvm::none_of(op.getOperandTypes(),
708                            [](Type type) { return type.isF32(); });
709     });
710     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
711       return converter.isSignatureLegal(op.getFunctionType()) &&
712              converter.isLegal(&op.getBody());
713     });
714     target.addDynamicallyLegalOp<func::CallOp>(
715         [&](func::CallOp op) { return converter.isLegal(op); });
716 
717     // TestCreateUnregisteredOp creates `arith.constant` operation,
718     // which was not added to target intentionally to test
719     // correct error code from conversion driver.
720     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
721 
722     // Expect the type_producer/type_consumer operations to only operate on f64.
723     target.addDynamicallyLegalOp<TestTypeProducerOp>(
724         [](TestTypeProducerOp op) { return op.getType().isF64(); });
725     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
726       return op.getOperand().getType().isF64();
727     });
728 
729     // Check support for marking certain operations as recursively legal.
730     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
731       return static_cast<bool>(
732           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
733     });
734 
735     // Mark the bound recursion operation as dynamically legal.
736     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
737         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
738 
739     // Handle a partial conversion.
740     if (mode == ConversionMode::Partial) {
741       DenseSet<Operation *> unlegalizedOps;
742       if (failed(applyPartialConversion(
743               getOperation(), target, std::move(patterns), &unlegalizedOps))) {
744         getOperation()->emitRemark() << "applyPartialConversion failed";
745       }
746       // Emit remarks for each legalizable operation.
747       for (auto *op : unlegalizedOps)
748         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
749       return;
750     }
751 
752     // Handle a full conversion.
753     if (mode == ConversionMode::Full) {
754       // Check support for marking unknown operations as dynamically legal.
755       target.markUnknownOpDynamicallyLegal([](Operation *op) {
756         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
757       });
758 
759       if (failed(applyFullConversion(getOperation(), target,
760                                      std::move(patterns)))) {
761         getOperation()->emitRemark() << "applyFullConversion failed";
762       }
763       return;
764     }
765 
766     // Otherwise, handle an analysis conversion.
767     assert(mode == ConversionMode::Analysis);
768 
769     // Analyze the convertible operations.
770     DenseSet<Operation *> legalizedOps;
771     if (failed(applyAnalysisConversion(getOperation(), target,
772                                        std::move(patterns), legalizedOps)))
773       return signalPassFailure();
774 
775     // Emit remarks for each legalizable operation.
776     for (auto *op : legalizedOps)
777       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
778   }
779 
780   /// The mode of conversion to use.
781   ConversionMode mode;
782 };
783 } // namespace
784 
/// Command-line option `-test-legalize-mode` selecting which conversion mode
/// TestLegalizePatternDriver runs in (analysis, full, or partial). Its value is
/// passed to the pass constructor in registerPatternsTestPass. Defaults to a
/// partial conversion.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
797 
798 //===----------------------------------------------------------------------===//
799 // ConversionPatternRewriter::getRemappedValue testing. This method is used
800 // to get the remapped value of an original value that was replaced using
801 // ConversionPatternRewriter.
802 namespace {
803 struct TestRemapValueTypeConverter : public TypeConverter {
804   using TypeConverter::TypeConverter;
805 
806   TestRemapValueTypeConverter() {
807     addConversion(
808         [](Float32Type type) { return Float64Type::get(type.getContext()); });
809     addConversion([](Type type) { return type; });
810   }
811 };
812 
813 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
814 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
815 /// operand twice.
816 ///
817 /// Example:
818 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
819 /// is replaced with:
820 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
821 struct OneVResOneVOperandOp1Converter
822     : public OpConversionPattern<OneVResOneVOperandOp1> {
823   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
824 
825   LogicalResult
826   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
827                   ConversionPatternRewriter &rewriter) const override {
828     auto origOps = op.getOperands();
829     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
830            "One operand expected");
831     Value origOp = *origOps.begin();
832     SmallVector<Value, 2> remappedOperands;
833     // Replicate the remapped original operand twice. Note that we don't used
834     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
835     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
836     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
837 
838     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
839                                                        remappedOperands);
840     return success();
841   }
842 };
843 
844 /// A rewriter pattern that tests that blocks can be merged.
845 struct TestRemapValueInRegion
846     : public OpConversionPattern<TestRemappedValueRegionOp> {
847   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
848 
849   LogicalResult
850   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
851                   ConversionPatternRewriter &rewriter) const final {
852     Block &block = op.getBody().front();
853     Operation *terminator = block.getTerminator();
854 
855     // Merge the block into the parent region.
856     Block *parentBlock = op->getBlock();
857     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
858     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
859     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
860 
861     // Replace the results of this operation with the remapped terminator
862     // values.
863     SmallVector<Value> terminatorOperands;
864     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
865                                           terminatorOperands)))
866       return failure();
867 
868     rewriter.eraseOp(terminator);
869     rewriter.replaceOp(op, terminatorOperands);
870     return success();
871   }
872 };
873 
874 struct TestRemappedValue
875     : public mlir::PassWrapper<TestRemappedValue, OperationPass<func::FuncOp>> {
876   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
877 
878   StringRef getArgument() const final { return "test-remapped-value"; }
879   StringRef getDescription() const final {
880     return "Test public remapped value mechanism in ConversionPatternRewriter";
881   }
882   void runOnOperation() override {
883     TestRemapValueTypeConverter typeConverter;
884 
885     mlir::RewritePatternSet patterns(&getContext());
886     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
887     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
888         &getContext());
889     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
890 
891     mlir::ConversionTarget target(getContext());
892     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
893 
894     // Expect the type_producer/type_consumer operations to only operate on f64.
895     target.addDynamicallyLegalOp<TestTypeProducerOp>(
896         [](TestTypeProducerOp op) { return op.getType().isF64(); });
897     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
898       return op.getOperand().getType().isF64();
899     });
900 
901     // We make OneVResOneVOperandOp1 legal only when it has more that one
902     // operand. This will trigger the conversion that will replace one-operand
903     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
904     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
905         [](Operation *op) { return op->getNumOperands() > 1; });
906 
907     if (failed(mlir::applyFullConversion(getOperation(), target,
908                                          std::move(patterns)))) {
909       signalPassFailure();
910     }
911   }
912 };
913 } // namespace
914 
915 //===----------------------------------------------------------------------===//
916 // Test patterns without a specific root operation kind
917 //===----------------------------------------------------------------------===//
918 
919 namespace {
920 /// This pattern matches and removes any operation in the test dialect.
921 struct RemoveTestDialectOps : public RewritePattern {
922   RemoveTestDialectOps(MLIRContext *context)
923       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
924 
925   LogicalResult matchAndRewrite(Operation *op,
926                                 PatternRewriter &rewriter) const override {
927     if (!isa<TestDialect>(op->getDialect()))
928       return failure();
929     rewriter.eraseOp(op);
930     return success();
931   }
932 };
933 
934 struct TestUnknownRootOpDriver
935     : public mlir::PassWrapper<TestUnknownRootOpDriver,
936                                OperationPass<func::FuncOp>> {
937   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
938 
939   StringRef getArgument() const final {
940     return "test-legalize-unknown-root-patterns";
941   }
942   StringRef getDescription() const final {
943     return "Test public remapped value mechanism in ConversionPatternRewriter";
944   }
945   void runOnOperation() override {
946     mlir::RewritePatternSet patterns(&getContext());
947     patterns.add<RemoveTestDialectOps>(&getContext());
948 
949     mlir::ConversionTarget target(getContext());
950     target.addIllegalDialect<TestDialect>();
951     if (failed(applyPartialConversion(getOperation(), target,
952                                       std::move(patterns))))
953       signalPassFailure();
954   }
955 };
956 } // namespace
957 
958 //===----------------------------------------------------------------------===//
959 // Test type conversions
960 //===----------------------------------------------------------------------===//
961 
962 namespace {
963 struct TestTypeConversionProducer
964     : public OpConversionPattern<TestTypeProducerOp> {
965   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
966   LogicalResult
967   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
968                   ConversionPatternRewriter &rewriter) const final {
969     Type resultType = op.getType();
970     Type convertedType = getTypeConverter()
971                              ? getTypeConverter()->convertType(resultType)
972                              : resultType;
973     if (resultType.isa<FloatType>())
974       resultType = rewriter.getF64Type();
975     else if (resultType.isInteger(16))
976       resultType = rewriter.getIntegerType(64);
977     else if (resultType.isa<test::TestRecursiveType>() &&
978              convertedType != resultType)
979       resultType = convertedType;
980     else
981       return failure();
982 
983     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
984     return success();
985   }
986 };
987 
988 /// Call signature conversion and then fail the rewrite to trigger the undo
989 /// mechanism.
990 struct TestSignatureConversionUndo
991     : public OpConversionPattern<TestSignatureConversionUndoOp> {
992   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
993 
994   LogicalResult
995   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
996                   ConversionPatternRewriter &rewriter) const final {
997     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
998     return failure();
999   }
1000 };
1001 
1002 /// Call signature conversion without providing a type converter to handle
1003 /// materializations.
1004 struct TestTestSignatureConversionNoConverter
1005     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1006   TestTestSignatureConversionNoConverter(TypeConverter &converter,
1007                                          MLIRContext *context)
1008       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1009         converter(converter) {}
1010 
1011   LogicalResult
1012   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1013                   ConversionPatternRewriter &rewriter) const final {
1014     Region &region = op->getRegion(0);
1015     Block *entry = &region.front();
1016 
1017     // Convert the original entry arguments.
1018     TypeConverter::SignatureConversion result(entry->getNumArguments());
1019     if (failed(
1020             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1021       return failure();
1022     rewriter.updateRootInPlace(
1023         op, [&] { rewriter.applySignatureConversion(&region, result); });
1024     return success();
1025   }
1026 
1027   TypeConverter &converter;
1028 };
1029 
1030 /// Just forward the operands to the root op. This is essentially a no-op
1031 /// pattern that is used to trigger target materialization.
1032 struct TestTypeConsumerForward
1033     : public OpConversionPattern<TestTypeConsumerOp> {
1034   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1035 
1036   LogicalResult
1037   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1038                   ConversionPatternRewriter &rewriter) const final {
1039     rewriter.updateRootInPlace(op,
1040                                [&] { op->setOperands(adaptor.getOperands()); });
1041     return success();
1042   }
1043 };
1044 
1045 struct TestTypeConversionAnotherProducer
1046     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1047   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1048 
1049   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1050                                 PatternRewriter &rewriter) const final {
1051     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1052     return success();
1053   }
1054 };
1055 
1056 struct TestTypeConversionDriver
1057     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
1058   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1059 
1060   void getDependentDialects(DialectRegistry &registry) const override {
1061     registry.insert<TestDialect>();
1062   }
1063   StringRef getArgument() const final {
1064     return "test-legalize-type-conversion";
1065   }
1066   StringRef getDescription() const final {
1067     return "Test various type conversion functionalities in DialectConversion";
1068   }
1069 
1070   void runOnOperation() override {
1071     // Initialize the type converter.
1072     TypeConverter converter;
1073 
1074     /// Add the legal set of type conversions.
1075     converter.addConversion([](Type type) -> Type {
1076       // Treat F64 as legal.
1077       if (type.isF64())
1078         return type;
1079       // Allow converting BF16/F16/F32 to F64.
1080       if (type.isBF16() || type.isF16() || type.isF32())
1081         return FloatType::getF64(type.getContext());
1082       // Otherwise, the type is illegal.
1083       return nullptr;
1084     });
1085     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1086       // Drop all integer types.
1087       return success();
1088     });
1089     converter.addConversion(
1090         // Convert a recursive self-referring type into a non-self-referring
1091         // type named "outer_converted_type" that contains a SimpleAType.
1092         [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
1093             ArrayRef<Type> callStack) -> Optional<LogicalResult> {
1094           // If the type is already converted, return it to indicate that it is
1095           // legal.
1096           if (type.getName() == "outer_converted_type") {
1097             results.push_back(type);
1098             return success();
1099           }
1100 
1101           // If the type is on the call stack more than once (it is there at
1102           // least once because of the _current_ call, which is always the last
1103           // element on the stack), we've hit the recursive case. Just return
1104           // SimpleAType here to create a non-recursive type as a result.
1105           if (llvm::is_contained(callStack.drop_back(), type)) {
1106             results.push_back(test::SimpleAType::get(type.getContext()));
1107             return success();
1108           }
1109 
1110           // Convert the body recursively.
1111           auto result = test::TestRecursiveType::get(type.getContext(),
1112                                                      "outer_converted_type");
1113           if (failed(result.setBody(converter.convertType(type.getBody()))))
1114             return failure();
1115           results.push_back(result);
1116           return success();
1117         });
1118 
1119     /// Add the legal set of type materializations.
1120     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1121                                           ValueRange inputs,
1122                                           Location loc) -> Value {
1123       // Allow casting from F64 back to F32.
1124       if (!resultType.isF16() && inputs.size() == 1 &&
1125           inputs[0].getType().isF64())
1126         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1127       // Allow producing an i32 or i64 from nothing.
1128       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1129           inputs.empty())
1130         return builder.create<TestTypeProducerOp>(loc, resultType);
1131       // Allow producing an i64 from an integer.
1132       if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
1133           inputs[0].getType().isa<IntegerType>())
1134         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1135       // Otherwise, fail.
1136       return nullptr;
1137     });
1138 
1139     // Initialize the conversion target.
1140     mlir::ConversionTarget target(getContext());
1141     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1142       auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
1143       return op.getType().isF64() || op.getType().isInteger(64) ||
1144              (recursiveType &&
1145               recursiveType.getName() == "outer_converted_type");
1146     });
1147     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1148       return converter.isSignatureLegal(op.getFunctionType()) &&
1149              converter.isLegal(&op.getBody());
1150     });
1151     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1152       // Allow casts from F64 to F32.
1153       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1154     });
1155     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1156         [&](TestSignatureConversionNoConverterOp op) {
1157           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1158         });
1159 
1160     // Initialize the set of rewrite patterns.
1161     RewritePatternSet patterns(&getContext());
1162     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1163                  TestSignatureConversionUndo,
1164                  TestTestSignatureConversionNoConverter>(converter,
1165                                                          &getContext());
1166     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1167     mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1168         patterns, converter);
1169 
1170     if (failed(applyPartialConversion(getOperation(), target,
1171                                       std::move(patterns))))
1172       signalPassFailure();
1173   }
1174 };
1175 } // namespace
1176 
1177 //===----------------------------------------------------------------------===//
1178 // Test Target Materialization With No Uses
1179 //===----------------------------------------------------------------------===//
1180 
1181 namespace {
1182 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1183   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1184 
1185   LogicalResult
1186   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1187                   ConversionPatternRewriter &rewriter) const final {
1188     rewriter.replaceOp(op, adaptor.getOperands());
1189     return success();
1190   }
1191 };
1192 
1193 struct TestTargetMaterializationWithNoUses
1194     : public PassWrapper<TestTargetMaterializationWithNoUses,
1195                          OperationPass<ModuleOp>> {
1196   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1197       TestTargetMaterializationWithNoUses)
1198 
1199   StringRef getArgument() const final {
1200     return "test-target-materialization-with-no-uses";
1201   }
1202   StringRef getDescription() const final {
1203     return "Test a special case of target materialization in DialectConversion";
1204   }
1205 
1206   void runOnOperation() override {
1207     TypeConverter converter;
1208     converter.addConversion([](Type t) { return t; });
1209     converter.addConversion([](IntegerType intTy) -> Type {
1210       if (intTy.getWidth() == 16)
1211         return IntegerType::get(intTy.getContext(), 64);
1212       return intTy;
1213     });
1214     converter.addTargetMaterialization(
1215         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1216           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1217         });
1218 
1219     ConversionTarget target(getContext());
1220     target.addIllegalOp<TestTypeChangerOp>();
1221 
1222     RewritePatternSet patterns(&getContext());
1223     patterns.add<ForwardOperandPattern>(converter, &getContext());
1224 
1225     if (failed(applyPartialConversion(getOperation(), target,
1226                                       std::move(patterns))))
1227       signalPassFailure();
1228   }
1229 };
1230 } // namespace
1231 
1232 //===----------------------------------------------------------------------===//
1233 // Test Block Merging
1234 //===----------------------------------------------------------------------===//
1235 
1236 namespace {
1237 /// A rewriter pattern that tests that blocks can be merged.
1238 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1239   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1240 
1241   LogicalResult
1242   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1243                   ConversionPatternRewriter &rewriter) const final {
1244     Block &firstBlock = op.getBody().front();
1245     Operation *branchOp = firstBlock.getTerminator();
1246     Block *secondBlock = &*(std::next(op.getBody().begin()));
1247     auto succOperands = branchOp->getOperands();
1248     SmallVector<Value, 2> replacements(succOperands);
1249     rewriter.eraseOp(branchOp);
1250     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1251     rewriter.updateRootInPlace(op, [] {});
1252     return success();
1253   }
1254 };
1255 
1256 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1257 struct TestUndoBlocksMerge : public ConversionPattern {
1258   TestUndoBlocksMerge(MLIRContext *ctx)
1259       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1260   LogicalResult
1261   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1262                   ConversionPatternRewriter &rewriter) const final {
1263     Block &firstBlock = op->getRegion(0).front();
1264     Operation *branchOp = firstBlock.getTerminator();
1265     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1266     rewriter.setInsertionPointToStart(secondBlock);
1267     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1268     auto succOperands = branchOp->getOperands();
1269     SmallVector<Value, 2> replacements(succOperands);
1270     rewriter.eraseOp(branchOp);
1271     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1272     rewriter.updateRootInPlace(op, [] {});
1273     return success();
1274   }
1275 };
1276 
1277 /// A rewrite mechanism to inline the body of the op into its parent, when both
1278 /// ops can have a single block.
1279 struct TestMergeSingleBlockOps
1280     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1281   using OpConversionPattern<
1282       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1283 
1284   LogicalResult
1285   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1286                   ConversionPatternRewriter &rewriter) const final {
1287     SingleBlockImplicitTerminatorOp parentOp =
1288         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1289     if (!parentOp)
1290       return failure();
1291     Block &innerBlock = op.getRegion().front();
1292     TerminatorOp innerTerminator =
1293         cast<TerminatorOp>(innerBlock.getTerminator());
1294     rewriter.mergeBlockBefore(&innerBlock, op);
1295     rewriter.eraseOp(innerTerminator);
1296     rewriter.eraseOp(op);
1297     rewriter.updateRootInPlace(op, [] {});
1298     return success();
1299   }
1300 };
1301 
1302 struct TestMergeBlocksPatternDriver
1303     : public PassWrapper<TestMergeBlocksPatternDriver,
1304                          OperationPass<ModuleOp>> {
1305   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
1306 
1307   StringRef getArgument() const final { return "test-merge-blocks"; }
1308   StringRef getDescription() const final {
1309     return "Test Merging operation in ConversionPatternRewriter";
1310   }
1311   void runOnOperation() override {
1312     MLIRContext *context = &getContext();
1313     mlir::RewritePatternSet patterns(context);
1314     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1315         context);
1316     ConversionTarget target(*context);
1317     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1318                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1319     target.addIllegalOp<ILLegalOpF>();
1320 
1321     /// Expect the op to have a single block after legalization.
1322     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1323         [&](TestMergeBlocksOp op) -> bool {
1324           return llvm::hasSingleElement(op.getBody());
1325         });
1326 
1327     /// Only allow `test.br` within test.merge_blocks op.
1328     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1329       return op->getParentOfType<TestMergeBlocksOp>();
1330     });
1331 
1332     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1333     /// inlined.
1334     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1335         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1336           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1337         });
1338 
1339     DenseSet<Operation *> unlegalizedOps;
1340     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1341                                  &unlegalizedOps);
1342     for (auto *op : unlegalizedOps)
1343       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1344   }
1345 };
1346 } // namespace
1347 
1348 //===----------------------------------------------------------------------===//
1349 // Test Selective Replacement
1350 //===----------------------------------------------------------------------===//
1351 
1352 namespace {
1353 /// A rewrite mechanism to inline the body of the op into its parent, when both
1354 /// ops can have a single block.
1355 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1356   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1357 
1358   LogicalResult matchAndRewrite(TestCastOp op,
1359                                 PatternRewriter &rewriter) const final {
1360     if (op.getNumOperands() != 2)
1361       return failure();
1362     OperandRange operands = op.getOperands();
1363 
1364     // Replace non-terminator uses with the first operand.
1365     rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1366       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1367     });
1368     // Replace everything else with the second operand if the operation isn't
1369     // dead.
1370     rewriter.replaceOp(op, op.getOperand(1));
1371     return success();
1372   }
1373 };
1374 
1375 struct TestSelectiveReplacementPatternDriver
1376     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1377                          OperationPass<>> {
1378   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1379       TestSelectiveReplacementPatternDriver)
1380 
1381   StringRef getArgument() const final {
1382     return "test-pattern-selective-replacement";
1383   }
1384   StringRef getDescription() const final {
1385     return "Test selective replacement in the PatternRewriter";
1386   }
1387   void runOnOperation() override {
1388     MLIRContext *context = &getContext();
1389     mlir::RewritePatternSet patterns(context);
1390     patterns.add<TestSelectiveOpReplacementPattern>(context);
1391     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1392                                        std::move(patterns));
1393   }
1394 };
1395 } // namespace
1396 
1397 //===----------------------------------------------------------------------===//
1398 // PassRegistration
1399 //===----------------------------------------------------------------------===//
1400 
1401 namespace mlir {
1402 namespace test {
1403 void registerPatternsTestPass() {
1404   PassRegistration<TestReturnTypeDriver>();
1405 
1406   PassRegistration<TestDerivedAttributeDriver>();
1407 
1408   PassRegistration<TestPatternDriver>();
1409 
1410   PassRegistration<TestLegalizePatternDriver>([] {
1411     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
1412   });
1413 
1414   PassRegistration<TestRemappedValue>();
1415 
1416   PassRegistration<TestUnknownRootOpDriver>();
1417 
1418   PassRegistration<TestTypeConversionDriver>();
1419   PassRegistration<TestTargetMaterializationWithNoUses>();
1420 
1421   PassRegistration<TestMergeBlocksPatternDriver>();
1422   PassRegistration<TestSelectiveReplacementPatternDriver>();
1423 }
1424 } // namespace test
1425 } // namespace mlir
1426