1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/FoldUtils.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace test;
23 
24 // Native function for testing NativeCodeCall
25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
26   return choice.getValue() ? input1 : input2;
27 }
28 
29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
30   rewriter.create<OpI>(loc, input);
31 }
32 
33 static void handleNoResultOp(PatternRewriter &rewriter,
34                              OpSymbolBindingNoResult op) {
35   // Turn the no result op to a one-result op.
36   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
37                                     op.getOperand());
38 }
39 
40 static bool getFirstI32Result(Operation *op, Value &value) {
41   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
42     return false;
43   value = op->getResult(0);
44   return true;
45 }
46 
47 static Value bindNativeCodeCallResult(Value value) { return value; }
48 
49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
50                                                               Value input2) {
51   return SmallVector<Value, 2>({input2, input1});
52 }
53 
54 // Test that natives calls are only called once during rewrites.
55 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
56 // This let us check the number of times OpM_Test was called by inspecting
57 // the returned value in the MLIR output.
58 static int64_t opMIncreasingValue = 314159265;
59 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
60   int64_t i = opMIncreasingValue++;
61   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
62 }
63 
64 namespace {
65 #include "TestPatterns.inc"
66 } // namespace
67 
68 //===----------------------------------------------------------------------===//
69 // Test Reduce Pattern Interface
70 //===----------------------------------------------------------------------===//
71 
72 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
73   populateWithGenerated(patterns);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Canonicalizer Driver.
78 //===----------------------------------------------------------------------===//
79 
80 namespace {
81 struct FoldingPattern : public RewritePattern {
82 public:
83   FoldingPattern(MLIRContext *context)
84       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
85                        /*benefit=*/1, context) {}
86 
87   LogicalResult matchAndRewrite(Operation *op,
88                                 PatternRewriter &rewriter) const override {
89     // Exercise OperationFolder API for a single-result operation that is folded
90     // upon construction. The operation being created through the folder has an
91     // in-place folder, and it should be still present in the output.
92     // Furthermore, the folder should not crash when attempting to recover the
93     // (unchanged) operation result.
94     OperationFolder folder(op->getContext());
95     Value result = folder.create<TestOpInPlaceFold>(
96         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
97         rewriter.getI32IntegerAttr(0));
98     assert(result);
99     rewriter.replaceOp(op, result);
100     return success();
101   }
102 };
103 
104 /// This pattern creates a foldable operation at the entry point of the block.
105 /// This tests the situation where the operation folder will need to replace an
106 /// operation with a previously created constant that does not initially
107 /// dominate the operation to replace.
108 struct FolderInsertBeforePreviouslyFoldedConstantPattern
109     : public OpRewritePattern<TestCastOp> {
110 public:
111   using OpRewritePattern<TestCastOp>::OpRewritePattern;
112 
113   LogicalResult matchAndRewrite(TestCastOp op,
114                                 PatternRewriter &rewriter) const override {
115     if (!op->hasAttr("test_fold_before_previously_folded_op"))
116       return failure();
117     rewriter.setInsertionPointToStart(op->getBlock());
118 
119     auto constOp = rewriter.create<arith::ConstantOp>(
120         op.getLoc(), rewriter.getBoolAttr(true));
121     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
122                                             Value(constOp));
123     return success();
124   }
125 };
126 
127 /// This pattern matches test.op_commutative2 with the first operand being
128 /// another test.op_commutative2 with a constant on the right side and fold it
129 /// away by propagating it as its result. This is intend to check that patterns
130 /// are applied after the commutative property moves constant to the right.
131 struct FolderCommutativeOp2WithConstant
132     : public OpRewritePattern<TestCommutative2Op> {
133 public:
134   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
135 
136   LogicalResult matchAndRewrite(TestCommutative2Op op,
137                                 PatternRewriter &rewriter) const override {
138     auto operand =
139         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
140     if (!operand)
141       return failure();
142     Attribute constInput;
143     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
144       return failure();
145     rewriter.replaceOp(op, operand->getOperand(1));
146     return success();
147   }
148 };
149 
150 struct TestPatternDriver
151     : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
152   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
153 
154   StringRef getArgument() const final { return "test-patterns"; }
155   StringRef getDescription() const final { return "Run test dialect patterns"; }
156   void runOnOperation() override {
157     mlir::RewritePatternSet patterns(&getContext());
158     populateWithGenerated(patterns);
159 
160     // Verify named pattern is generated with expected name.
161     patterns.add<FoldingPattern, TestNamedPatternRule,
162                  FolderInsertBeforePreviouslyFoldedConstantPattern,
163                  FolderCommutativeOp2WithConstant>(&getContext());
164 
165     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
166   }
167 };
168 } // namespace
169 
170 //===----------------------------------------------------------------------===//
171 // ReturnType Driver.
172 //===----------------------------------------------------------------------===//
173 
174 namespace {
175 // Generate ops for each instance where the type can be successfully inferred.
176 template <typename OpTy>
177 static void invokeCreateWithInferredReturnType(Operation *op) {
178   auto *context = op->getContext();
179   auto fop = op->getParentOfType<func::FuncOp>();
180   auto location = UnknownLoc::get(context);
181   OpBuilder b(op);
182   b.setInsertionPointAfter(op);
183 
184   // Use permutations of 2 args as operands.
185   assert(fop.getNumArguments() >= 2);
186   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
187     for (int j = 0; j < e; ++j) {
188       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
189       SmallVector<Type, 2> inferredReturnTypes;
190       if (succeeded(OpTy::inferReturnTypes(
191               context, llvm::None, values, op->getAttrDictionary(),
192               op->getRegions(), inferredReturnTypes))) {
193         OperationState state(location, OpTy::getOperationName());
194         // TODO: Expand to regions.
195         OpTy::build(b, state, values, op->getAttrs());
196         (void)b.create(state);
197       }
198     }
199   }
200 }
201 
202 static void reifyReturnShape(Operation *op) {
203   OpBuilder b(op);
204 
205   // Use permutations of 2 args as operands.
206   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
207   SmallVector<Value, 2> shapes;
208   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
209       !llvm::hasSingleElement(shapes))
210     return;
211   for (const auto &it : llvm::enumerate(shapes)) {
212     op->emitRemark() << "value " << it.index() << ": "
213                      << it.value().getDefiningOp();
214   }
215 }
216 
217 struct TestReturnTypeDriver
218     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
219   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
220 
221   void getDependentDialects(DialectRegistry &registry) const override {
222     registry.insert<tensor::TensorDialect>();
223   }
224   StringRef getArgument() const final { return "test-return-type"; }
225   StringRef getDescription() const final { return "Run return type functions"; }
226 
227   void runOnOperation() override {
228     if (getOperation().getName() == "testCreateFunctions") {
229       std::vector<Operation *> ops;
230       // Collect ops to avoid triggering on inserted ops.
231       for (auto &op : getOperation().getBody().front())
232         ops.push_back(&op);
233       // Generate test patterns for each, but skip terminator.
234       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
235         // Test create method of each of the Op classes below. The resultant
236         // output would be in reverse order underneath `op` from which
237         // the attributes and regions are used.
238         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
239         invokeCreateWithInferredReturnType<
240             OpWithShapedTypeInferTypeInterfaceOp>(op);
241       };
242       return;
243     }
244     if (getOperation().getName() == "testReifyFunctions") {
245       std::vector<Operation *> ops;
246       // Collect ops to avoid triggering on inserted ops.
247       for (auto &op : getOperation().getBody().front())
248         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
249           ops.push_back(&op);
250       // Generate test patterns for each, but skip terminator.
251       for (auto *op : ops)
252         reifyReturnShape(op);
253     }
254   }
255 };
256 } // namespace
257 
258 namespace {
259 struct TestDerivedAttributeDriver
260     : public PassWrapper<TestDerivedAttributeDriver,
261                          OperationPass<func::FuncOp>> {
262   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
263 
264   StringRef getArgument() const final { return "test-derived-attr"; }
265   StringRef getDescription() const final {
266     return "Run test derived attributes";
267   }
268   void runOnOperation() override;
269 };
270 } // namespace
271 
272 void TestDerivedAttributeDriver::runOnOperation() {
273   getOperation().walk([](DerivedAttributeOpInterface dOp) {
274     auto dAttr = dOp.materializeDerivedAttributes();
275     if (!dAttr)
276       return;
277     for (auto d : dAttr)
278       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
279   });
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Legalization Driver.
284 //===----------------------------------------------------------------------===//
285 
286 namespace {
287 //===----------------------------------------------------------------------===//
288 // Region-Block Rewrite Testing
289 
290 /// This pattern is a simple pattern that inlines the first region of a given
291 /// operation into the parent region.
292 struct TestRegionRewriteBlockMovement : public ConversionPattern {
293   TestRegionRewriteBlockMovement(MLIRContext *ctx)
294       : ConversionPattern("test.region", 1, ctx) {}
295 
296   LogicalResult
297   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
298                   ConversionPatternRewriter &rewriter) const final {
299     // Inline this region into the parent region.
300     auto &parentRegion = *op->getParentRegion();
301     auto &opRegion = op->getRegion(0);
302     if (op->getAttr("legalizer.should_clone"))
303       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
304     else
305       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
306 
307     if (op->getAttr("legalizer.erase_old_blocks")) {
308       while (!opRegion.empty())
309         rewriter.eraseBlock(&opRegion.front());
310     }
311 
312     // Drop this operation.
313     rewriter.eraseOp(op);
314     return success();
315   }
316 };
317 /// This pattern is a simple pattern that generates a region containing an
318 /// illegal operation.
319 struct TestRegionRewriteUndo : public RewritePattern {
320   TestRegionRewriteUndo(MLIRContext *ctx)
321       : RewritePattern("test.region_builder", 1, ctx) {}
322 
323   LogicalResult matchAndRewrite(Operation *op,
324                                 PatternRewriter &rewriter) const final {
325     // Create the region operation with an entry block containing arguments.
326     OperationState newRegion(op->getLoc(), "test.region");
327     newRegion.addRegion();
328     auto *regionOp = rewriter.create(newRegion);
329     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
330     entryBlock->addArgument(rewriter.getIntegerType(64),
331                             rewriter.getUnknownLoc());
332 
333     // Add an explicitly illegal operation to ensure the conversion fails.
334     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
335     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
336 
337     // Drop this operation.
338     rewriter.eraseOp(op);
339     return success();
340   }
341 };
342 /// A simple pattern that creates a block at the end of the parent region of the
343 /// matched operation.
344 struct TestCreateBlock : public RewritePattern {
345   TestCreateBlock(MLIRContext *ctx)
346       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
347 
348   LogicalResult matchAndRewrite(Operation *op,
349                                 PatternRewriter &rewriter) const final {
350     Region &region = *op->getParentRegion();
351     Type i32Type = rewriter.getIntegerType(32);
352     Location loc = op->getLoc();
353     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
354     rewriter.create<TerminatorOp>(loc);
355     rewriter.replaceOp(op, {});
356     return success();
357   }
358 };
359 
360 /// A simple pattern that creates a block containing an invalid operation in
361 /// order to trigger the block creation undo mechanism.
362 struct TestCreateIllegalBlock : public RewritePattern {
363   TestCreateIllegalBlock(MLIRContext *ctx)
364       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
365 
366   LogicalResult matchAndRewrite(Operation *op,
367                                 PatternRewriter &rewriter) const final {
368     Region &region = *op->getParentRegion();
369     Type i32Type = rewriter.getIntegerType(32);
370     Location loc = op->getLoc();
371     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
372     // Create an illegal op to ensure the conversion fails.
373     rewriter.create<ILLegalOpF>(loc, i32Type);
374     rewriter.create<TerminatorOp>(loc);
375     rewriter.replaceOp(op, {});
376     return success();
377   }
378 };
379 
380 /// A simple pattern that tests the undo mechanism when replacing the uses of a
381 /// block argument.
382 struct TestUndoBlockArgReplace : public ConversionPattern {
383   TestUndoBlockArgReplace(MLIRContext *ctx)
384       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
385 
386   LogicalResult
387   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
388                   ConversionPatternRewriter &rewriter) const final {
389     auto illegalOp =
390         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
391     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
392                                         illegalOp);
393     rewriter.updateRootInPlace(op, [] {});
394     return success();
395   }
396 };
397 
398 /// A rewrite pattern that tests the undo mechanism when erasing a block.
399 struct TestUndoBlockErase : public ConversionPattern {
400   TestUndoBlockErase(MLIRContext *ctx)
401       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
402 
403   LogicalResult
404   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
405                   ConversionPatternRewriter &rewriter) const final {
406     Block *secondBlock = &*std::next(op->getRegion(0).begin());
407     rewriter.setInsertionPointToStart(secondBlock);
408     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
409     rewriter.eraseBlock(secondBlock);
410     rewriter.updateRootInPlace(op, [] {});
411     return success();
412   }
413 };
414 
415 //===----------------------------------------------------------------------===//
416 // Type-Conversion Rewrite Testing
417 
418 /// This patterns erases a region operation that has had a type conversion.
419 struct TestDropOpSignatureConversion : public ConversionPattern {
420   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
421       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
422   LogicalResult
423   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
424                   ConversionPatternRewriter &rewriter) const override {
425     Region &region = op->getRegion(0);
426     Block *entry = &region.front();
427 
428     // Convert the original entry arguments.
429     TypeConverter &converter = *getTypeConverter();
430     TypeConverter::SignatureConversion result(entry->getNumArguments());
431     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
432                                               result)) ||
433         failed(rewriter.convertRegionTypes(&region, converter, &result)))
434       return failure();
435 
436     // Convert the region signature and just drop the operation.
437     rewriter.eraseOp(op);
438     return success();
439   }
440 };
441 /// This pattern simply updates the operands of the given operation.
442 struct TestPassthroughInvalidOp : public ConversionPattern {
443   TestPassthroughInvalidOp(MLIRContext *ctx)
444       : ConversionPattern("test.invalid", 1, ctx) {}
445   LogicalResult
446   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
447                   ConversionPatternRewriter &rewriter) const final {
448     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
449                                              llvm::None);
450     return success();
451   }
452 };
453 /// This pattern handles the case of a split return value.
454 struct TestSplitReturnType : public ConversionPattern {
455   TestSplitReturnType(MLIRContext *ctx)
456       : ConversionPattern("test.return", 1, ctx) {}
457   LogicalResult
458   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
459                   ConversionPatternRewriter &rewriter) const final {
460     // Check for a return of F32.
461     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
462       return failure();
463 
464     // Check if the first operation is a cast operation, if it is we use the
465     // results directly.
466     auto *defOp = operands[0].getDefiningOp();
467     if (auto packerOp =
468             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
469       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
470       return success();
471     }
472 
473     // Otherwise, fail to match.
474     return failure();
475   }
476 };
477 
478 //===----------------------------------------------------------------------===//
479 // Multi-Level Type-Conversion Rewrite Testing
480 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
481   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
482       : ConversionPattern("test.type_producer", 1, ctx) {}
483   LogicalResult
484   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
485                   ConversionPatternRewriter &rewriter) const final {
486     // If the type is I32, change the type to F32.
487     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
488       return failure();
489     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
490     return success();
491   }
492 };
493 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
494   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
495       : ConversionPattern("test.type_producer", 1, ctx) {}
496   LogicalResult
497   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
498                   ConversionPatternRewriter &rewriter) const final {
499     // If the type is F32, change the type to F64.
500     if (!Type(*op->result_type_begin()).isF32())
501       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
502     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
503     return success();
504   }
505 };
506 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
507   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
508       : ConversionPattern("test.type_producer", 10, ctx) {}
509   LogicalResult
510   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
511                   ConversionPatternRewriter &rewriter) const final {
512     // Always convert to B16, even though it is not a legal type. This tests
513     // that values are unmapped correctly.
514     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
515     return success();
516   }
517 };
518 struct TestUpdateConsumerType : public ConversionPattern {
519   TestUpdateConsumerType(MLIRContext *ctx)
520       : ConversionPattern("test.type_consumer", 1, ctx) {}
521   LogicalResult
522   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
523                   ConversionPatternRewriter &rewriter) const final {
524     // Verify that the incoming operand has been successfully remapped to F64.
525     if (!operands[0].getType().isF64())
526       return failure();
527     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
528     return success();
529   }
530 };
531 
532 //===----------------------------------------------------------------------===//
533 // Non-Root Replacement Rewrite Testing
534 /// This pattern generates an invalid operation, but replaces it before the
535 /// pattern is finished. This checks that we don't need to legalize the
536 /// temporary op.
537 struct TestNonRootReplacement : public RewritePattern {
538   TestNonRootReplacement(MLIRContext *ctx)
539       : RewritePattern("test.replace_non_root", 1, ctx) {}
540 
541   LogicalResult matchAndRewrite(Operation *op,
542                                 PatternRewriter &rewriter) const final {
543     auto resultType = *op->result_type_begin();
544     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
545     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
546 
547     rewriter.replaceOp(illegalOp, {legalOp});
548     rewriter.replaceOp(op, {illegalOp});
549     return success();
550   }
551 };
552 
553 //===----------------------------------------------------------------------===//
554 // Recursive Rewrite Testing
555 /// This pattern is applied to the same operation multiple times, but has a
556 /// bounded recursion.
557 struct TestBoundedRecursiveRewrite
558     : public OpRewritePattern<TestRecursiveRewriteOp> {
559   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
560 
561   void initialize() {
562     // The conversion target handles bounding the recursion of this pattern.
563     setHasBoundedRewriteRecursion();
564   }
565 
566   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
567                                 PatternRewriter &rewriter) const final {
568     // Decrement the depth of the op in-place.
569     rewriter.updateRootInPlace(op, [&] {
570       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
571     });
572     return success();
573   }
574 };
575 
576 struct TestNestedOpCreationUndoRewrite
577     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
578   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
579 
580   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
581                                 PatternRewriter &rewriter) const final {
582     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
583     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
584     return success();
585   };
586 };
587 
588 // This pattern matches `test.blackhole` and delete this op and its producer.
589 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
590   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
591 
592   LogicalResult matchAndRewrite(BlackHoleOp op,
593                                 PatternRewriter &rewriter) const final {
594     Operation *producer = op.getOperand().getDefiningOp();
595     // Always erase the user before the producer, the framework should handle
596     // this correctly.
597     rewriter.eraseOp(op);
598     rewriter.eraseOp(producer);
599     return success();
600   };
601 };
602 
603 // This pattern replaces explicitly illegal op with explicitly legal op,
604 // but in addition creates unregistered operation.
605 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
606   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
607 
608   LogicalResult matchAndRewrite(ILLegalOpG op,
609                                 PatternRewriter &rewriter) const final {
610     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
611     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
612     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
613     return success();
614   };
615 };
616 } // namespace
617 
618 namespace {
619 struct TestTypeConverter : public TypeConverter {
620   using TypeConverter::TypeConverter;
621   TestTypeConverter() {
622     addConversion(convertType);
623     addArgumentMaterialization(materializeCast);
624     addSourceMaterialization(materializeCast);
625   }
626 
627   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
628     // Drop I16 types.
629     if (t.isSignlessInteger(16))
630       return success();
631 
632     // Convert I64 to F64.
633     if (t.isSignlessInteger(64)) {
634       results.push_back(FloatType::getF64(t.getContext()));
635       return success();
636     }
637 
638     // Convert I42 to I43.
639     if (t.isInteger(42)) {
640       results.push_back(IntegerType::get(t.getContext(), 43));
641       return success();
642     }
643 
644     // Split F32 into F16,F16.
645     if (t.isF32()) {
646       results.assign(2, FloatType::getF16(t.getContext()));
647       return success();
648     }
649 
650     // Otherwise, convert the type directly.
651     results.push_back(t);
652     return success();
653   }
654 
655   /// Hook for materializing a conversion. This is necessary because we generate
656   /// 1->N type mappings.
657   static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
658                                          ValueRange inputs, Location loc) {
659     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
660   }
661 };
662 
/// Pass `test-legalize-patterns`: applies the test legalization patterns with
/// the dialect conversion driver. The kind of conversion (analysis, full, or
/// partial) is selected by the `mode` member.
struct TestLegalizePatternDriver
    : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)

  StringRef getArgument() const final { return "test-legalize-patterns"; }
  StringRef getDescription() const final {
    return "Run test dialect legalization patterns";
  }
  /// The mode of conversion to use with the driver.
  enum class ConversionMode { Analysis, Full, Partial };

  /// Creates a driver that runs the conversion in `mode`.
  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}

  /// Declares the func dialect as a dependency.
  /// NOTE(review): presumably because the func/call type-conversion patterns
  /// may create func ops — confirm.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect>();
  }

  /// Builds the pattern set and conversion target, then runs the conversion
  /// selected by `mode`:
  ///  - Partial: applyPartialConversion. Emits a remark on the module if it
  ///    fails, and one remark per op reported as not legalizable.
  ///  - Full: applyFullConversion, additionally treating unknown ops tagged
  ///    `test.dynamically_legal` as legal. Emits a remark if it fails.
  ///  - Analysis: applyAnalysisConversion. Signals pass failure if it fails;
  ///    otherwise emits one remark per legalizable op.
  void runOnOperation() override {
    TestTypeConverter converter;
    mlir::RewritePatternSet patterns(&getContext());
    populateWithGenerated(patterns);
    // NOTE(review): several patterns share a root op name, so insertion order
    // may influence which equal-benefit pattern is tried first. Keep the
    // order stable.
    patterns
        .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
             TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
             TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
             TestNonRootReplacement, TestBoundedRecursiveRewrite,
             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
             TestCreateUnregisteredOp>(&getContext());
    // This pattern additionally takes the type converter.
    patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
    // Convert func signatures and call sites with the same type converter.
    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, converter);
    mlir::populateCallOpTypeConversionPattern(patterns, converter);

    // Define the conversion target used for the test.
    ConversionTarget target(getContext());
    target.addLegalOp<ModuleOp>();
    target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                      TerminatorOp>();
    target
        .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
      // Don't allow F32 operands.
      return llvm::none_of(op.getOperandTypes(),
                           [](Type type) { return type.isF32(); });
    });
    // A function is legal once its signature and body types are legal under
    // the test type converter.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<func::CallOp>(
        [&](func::CallOp op) { return converter.isLegal(op); });

    // TestCreateUnregisteredOp creates `arith.constant` operation,
    // which was not added to target intentionally to test
    // correct error code from conversion driver.
    target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });

    // Expect the type_producer/type_consumer operations to only operate on f64.
    target.addDynamicallyLegalOp<TestTypeProducerOp>(
        [](TestTypeProducerOp op) { return op.getType().isF64(); });
    target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
      return op.getOperand().getType().isF64();
    });

    // Check support for marking certain operations as recursively legal.
    target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
      return static_cast<bool>(
          op->getAttrOfType<UnitAttr>("test.recursively_legal"));
    });

    // Mark the bound recursion operation as dynamically legal.
    // Depth 0 is the stopping point for TestBoundedRecursiveRewrite.
    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
        [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });

    // Handle a partial conversion.
    if (mode == ConversionMode::Partial) {
      DenseSet<Operation *> unlegalizedOps;
      if (failed(applyPartialConversion(
              getOperation(), target, std::move(patterns), &unlegalizedOps))) {
        getOperation()->emitRemark() << "applyPartialConversion failed";
      }
      // Emit remarks for each operation that could not be legalized.
      for (auto *op : unlegalizedOps)
        op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
      return;
    }

    // Handle a full conversion.
    if (mode == ConversionMode::Full) {
      // Check support for marking unknown operations as dynamically legal.
      target.markUnknownOpDynamicallyLegal([](Operation *op) {
        return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
      });

      if (failed(applyFullConversion(getOperation(), target,
                                     std::move(patterns)))) {
        getOperation()->emitRemark() << "applyFullConversion failed";
      }
      return;
    }

    // Otherwise, handle an analysis conversion.
    assert(mode == ConversionMode::Analysis);

    // Analyze the convertible operations.
    DenseSet<Operation *> legalizedOps;
    if (failed(applyAnalysisConversion(getOperation(), target,
                                       std::move(patterns), legalizedOps)))
      return signalPassFailure();

    // Emit remarks for each legalizable operation.
    for (auto *op : legalizedOps)
      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
  }

  /// The mode of conversion to use.
  ConversionMode mode;
};
783 } // namespace
784 
/// Command-line option `-test-legalize-mode` that selects which conversion
/// entry point `TestLegalizePatternDriver` runs: an analysis conversion, a
/// full conversion, or a partial conversion. When the option is not given,
/// partial conversion is used. The value is read when the pass is constructed
/// in `registerPatternsTestPass`.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
797 
798 //===----------------------------------------------------------------------===//
799 // ConversionPatternRewriter::getRemappedValue testing. This method is used
800 // to get the remapped value of an original value that was replaced using
801 // ConversionPatternRewriter.
802 namespace {
803 struct TestRemapValueTypeConverter : public TypeConverter {
804   using TypeConverter::TypeConverter;
805 
806   TestRemapValueTypeConverter() {
807     addConversion(
808         [](Float32Type type) { return Float64Type::get(type.getContext()); });
809     addConversion([](Type type) { return type; });
810   }
811 };
812 
813 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
814 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
815 /// operand twice.
816 ///
817 /// Example:
818 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
819 /// is replaced with:
820 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
821 struct OneVResOneVOperandOp1Converter
822     : public OpConversionPattern<OneVResOneVOperandOp1> {
823   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
824 
825   LogicalResult
826   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
827                   ConversionPatternRewriter &rewriter) const override {
828     auto origOps = op.getOperands();
829     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
830            "One operand expected");
831     Value origOp = *origOps.begin();
832     SmallVector<Value, 2> remappedOperands;
833     // Replicate the remapped original operand twice. Note that we don't used
834     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
835     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
836     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
837 
838     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
839                                                        remappedOperands);
840     return success();
841   }
842 };
843 
844 /// A rewriter pattern that tests that blocks can be merged.
845 struct TestRemapValueInRegion
846     : public OpConversionPattern<TestRemappedValueRegionOp> {
847   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
848 
849   LogicalResult
850   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
851                   ConversionPatternRewriter &rewriter) const final {
852     Block &block = op.getBody().front();
853     Operation *terminator = block.getTerminator();
854 
855     // Merge the block into the parent region.
856     Block *parentBlock = op->getBlock();
857     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
858     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
859     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
860 
861     // Replace the results of this operation with the remapped terminator
862     // values.
863     SmallVector<Value> terminatorOperands;
864     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
865                                           terminatorOperands)))
866       return failure();
867 
868     rewriter.eraseOp(terminator);
869     rewriter.replaceOp(op, terminatorOperands);
870     return success();
871   }
872 };
873 
874 struct TestRemappedValue
875     : public mlir::PassWrapper<TestRemappedValue, OperationPass<func::FuncOp>> {
876   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
877 
878   StringRef getArgument() const final { return "test-remapped-value"; }
879   StringRef getDescription() const final {
880     return "Test public remapped value mechanism in ConversionPatternRewriter";
881   }
882   void runOnOperation() override {
883     TestRemapValueTypeConverter typeConverter;
884 
885     mlir::RewritePatternSet patterns(&getContext());
886     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
887     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
888         &getContext());
889     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
890 
891     mlir::ConversionTarget target(getContext());
892     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
893 
894     // Expect the type_producer/type_consumer operations to only operate on f64.
895     target.addDynamicallyLegalOp<TestTypeProducerOp>(
896         [](TestTypeProducerOp op) { return op.getType().isF64(); });
897     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
898       return op.getOperand().getType().isF64();
899     });
900 
901     // We make OneVResOneVOperandOp1 legal only when it has more that one
902     // operand. This will trigger the conversion that will replace one-operand
903     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
904     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
905         [](Operation *op) { return op->getNumOperands() > 1; });
906 
907     if (failed(mlir::applyFullConversion(getOperation(), target,
908                                          std::move(patterns)))) {
909       signalPassFailure();
910     }
911   }
912 };
913 } // namespace
914 
915 //===----------------------------------------------------------------------===//
916 // Test patterns without a specific root operation kind
917 //===----------------------------------------------------------------------===//
918 
919 namespace {
920 /// This pattern matches and removes any operation in the test dialect.
921 struct RemoveTestDialectOps : public RewritePattern {
922   RemoveTestDialectOps(MLIRContext *context)
923       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
924 
925   LogicalResult matchAndRewrite(Operation *op,
926                                 PatternRewriter &rewriter) const override {
927     if (!isa<TestDialect>(op->getDialect()))
928       return failure();
929     rewriter.eraseOp(op);
930     return success();
931   }
932 };
933 
934 struct TestUnknownRootOpDriver
935     : public mlir::PassWrapper<TestUnknownRootOpDriver,
936                                OperationPass<func::FuncOp>> {
937   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
938 
939   StringRef getArgument() const final {
940     return "test-legalize-unknown-root-patterns";
941   }
942   StringRef getDescription() const final {
943     return "Test public remapped value mechanism in ConversionPatternRewriter";
944   }
945   void runOnOperation() override {
946     mlir::RewritePatternSet patterns(&getContext());
947     patterns.add<RemoveTestDialectOps>(&getContext());
948 
949     mlir::ConversionTarget target(getContext());
950     target.addIllegalDialect<TestDialect>();
951     if (failed(applyPartialConversion(getOperation(), target,
952                                       std::move(patterns))))
953       signalPassFailure();
954   }
955 };
956 } // namespace
957 
958 //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
960 //===----------------------------------------------------------------------===//
961 
962 namespace {
963 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
964 /// replace them with dynamic operations 'test.generic_dynamic_op'.
965 struct RewriteDynamicOp : public RewritePattern {
966   RewriteDynamicOp(MLIRContext *context)
967       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
968                        context) {}
969 
970   LogicalResult matchAndRewrite(Operation *op,
971                                 PatternRewriter &rewriter) const override {
972     assert(op->getName().getStringRef() ==
973                "test.dynamic_one_operand_two_results" &&
974            "rewrite pattern should only match operations with the right name");
975 
976     OperationState state(op->getLoc(), "test.dynamic_generic",
977                          op->getOperands(), op->getResultTypes(),
978                          op->getAttrs());
979     auto *newOp = rewriter.create(state);
980     rewriter.replaceOp(op, newOp->getResults());
981     return success();
982   }
983 };
984 
985 struct TestRewriteDynamicOpDriver
986     : public PassWrapper<TestRewriteDynamicOpDriver,
987                          OperationPass<func::FuncOp>> {
988   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
989 
990   void getDependentDialects(DialectRegistry &registry) const override {
991     registry.insert<TestDialect>();
992   }
993   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
994   StringRef getDescription() const final {
995     return "Test rewritting on dynamic operations";
996   }
997   void runOnOperation() override {
998     RewritePatternSet patterns(&getContext());
999     patterns.add<RewriteDynamicOp>(&getContext());
1000 
1001     ConversionTarget target(getContext());
1002     target.addIllegalOp(
1003         OperationName("test.dynamic_one_operand_two_results", &getContext()));
1004     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1005     if (failed(applyPartialConversion(getOperation(), target,
1006                                       std::move(patterns))))
1007       signalPassFailure();
1008   }
1009 };
1010 } // end anonymous namespace
1011 
1012 //===----------------------------------------------------------------------===//
1013 // Test type conversions
1014 //===----------------------------------------------------------------------===//
1015 
1016 namespace {
1017 struct TestTypeConversionProducer
1018     : public OpConversionPattern<TestTypeProducerOp> {
1019   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1020   LogicalResult
1021   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1022                   ConversionPatternRewriter &rewriter) const final {
1023     Type resultType = op.getType();
1024     Type convertedType = getTypeConverter()
1025                              ? getTypeConverter()->convertType(resultType)
1026                              : resultType;
1027     if (resultType.isa<FloatType>())
1028       resultType = rewriter.getF64Type();
1029     else if (resultType.isInteger(16))
1030       resultType = rewriter.getIntegerType(64);
1031     else if (resultType.isa<test::TestRecursiveType>() &&
1032              convertedType != resultType)
1033       resultType = convertedType;
1034     else
1035       return failure();
1036 
1037     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1038     return success();
1039   }
1040 };
1041 
1042 /// Call signature conversion and then fail the rewrite to trigger the undo
1043 /// mechanism.
1044 struct TestSignatureConversionUndo
1045     : public OpConversionPattern<TestSignatureConversionUndoOp> {
1046   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1047 
1048   LogicalResult
1049   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1050                   ConversionPatternRewriter &rewriter) const final {
1051     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1052     return failure();
1053   }
1054 };
1055 
1056 /// Call signature conversion without providing a type converter to handle
1057 /// materializations.
1058 struct TestTestSignatureConversionNoConverter
1059     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1060   TestTestSignatureConversionNoConverter(TypeConverter &converter,
1061                                          MLIRContext *context)
1062       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1063         converter(converter) {}
1064 
1065   LogicalResult
1066   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1067                   ConversionPatternRewriter &rewriter) const final {
1068     Region &region = op->getRegion(0);
1069     Block *entry = &region.front();
1070 
1071     // Convert the original entry arguments.
1072     TypeConverter::SignatureConversion result(entry->getNumArguments());
1073     if (failed(
1074             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1075       return failure();
1076     rewriter.updateRootInPlace(
1077         op, [&] { rewriter.applySignatureConversion(&region, result); });
1078     return success();
1079   }
1080 
1081   TypeConverter &converter;
1082 };
1083 
1084 /// Just forward the operands to the root op. This is essentially a no-op
1085 /// pattern that is used to trigger target materialization.
1086 struct TestTypeConsumerForward
1087     : public OpConversionPattern<TestTypeConsumerOp> {
1088   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1089 
1090   LogicalResult
1091   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1092                   ConversionPatternRewriter &rewriter) const final {
1093     rewriter.updateRootInPlace(op,
1094                                [&] { op->setOperands(adaptor.getOperands()); });
1095     return success();
1096   }
1097 };
1098 
1099 struct TestTypeConversionAnotherProducer
1100     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1101   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1102 
1103   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1104                                 PatternRewriter &rewriter) const final {
1105     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1106     return success();
1107   }
1108 };
1109 
1110 struct TestTypeConversionDriver
1111     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
1112   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1113 
1114   void getDependentDialects(DialectRegistry &registry) const override {
1115     registry.insert<TestDialect>();
1116   }
1117   StringRef getArgument() const final {
1118     return "test-legalize-type-conversion";
1119   }
1120   StringRef getDescription() const final {
1121     return "Test various type conversion functionalities in DialectConversion";
1122   }
1123 
1124   void runOnOperation() override {
1125     // Initialize the type converter.
1126     TypeConverter converter;
1127 
1128     /// Add the legal set of type conversions.
1129     converter.addConversion([](Type type) -> Type {
1130       // Treat F64 as legal.
1131       if (type.isF64())
1132         return type;
1133       // Allow converting BF16/F16/F32 to F64.
1134       if (type.isBF16() || type.isF16() || type.isF32())
1135         return FloatType::getF64(type.getContext());
1136       // Otherwise, the type is illegal.
1137       return nullptr;
1138     });
1139     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1140       // Drop all integer types.
1141       return success();
1142     });
1143     converter.addConversion(
1144         // Convert a recursive self-referring type into a non-self-referring
1145         // type named "outer_converted_type" that contains a SimpleAType.
1146         [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
1147             ArrayRef<Type> callStack) -> Optional<LogicalResult> {
1148           // If the type is already converted, return it to indicate that it is
1149           // legal.
1150           if (type.getName() == "outer_converted_type") {
1151             results.push_back(type);
1152             return success();
1153           }
1154 
1155           // If the type is on the call stack more than once (it is there at
1156           // least once because of the _current_ call, which is always the last
1157           // element on the stack), we've hit the recursive case. Just return
1158           // SimpleAType here to create a non-recursive type as a result.
1159           if (llvm::is_contained(callStack.drop_back(), type)) {
1160             results.push_back(test::SimpleAType::get(type.getContext()));
1161             return success();
1162           }
1163 
1164           // Convert the body recursively.
1165           auto result = test::TestRecursiveType::get(type.getContext(),
1166                                                      "outer_converted_type");
1167           if (failed(result.setBody(converter.convertType(type.getBody()))))
1168             return failure();
1169           results.push_back(result);
1170           return success();
1171         });
1172 
1173     /// Add the legal set of type materializations.
1174     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1175                                           ValueRange inputs,
1176                                           Location loc) -> Value {
1177       // Allow casting from F64 back to F32.
1178       if (!resultType.isF16() && inputs.size() == 1 &&
1179           inputs[0].getType().isF64())
1180         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1181       // Allow producing an i32 or i64 from nothing.
1182       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1183           inputs.empty())
1184         return builder.create<TestTypeProducerOp>(loc, resultType);
1185       // Allow producing an i64 from an integer.
1186       if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
1187           inputs[0].getType().isa<IntegerType>())
1188         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1189       // Otherwise, fail.
1190       return nullptr;
1191     });
1192 
1193     // Initialize the conversion target.
1194     mlir::ConversionTarget target(getContext());
1195     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1196       auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
1197       return op.getType().isF64() || op.getType().isInteger(64) ||
1198              (recursiveType &&
1199               recursiveType.getName() == "outer_converted_type");
1200     });
1201     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1202       return converter.isSignatureLegal(op.getFunctionType()) &&
1203              converter.isLegal(&op.getBody());
1204     });
1205     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1206       // Allow casts from F64 to F32.
1207       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1208     });
1209     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1210         [&](TestSignatureConversionNoConverterOp op) {
1211           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1212         });
1213 
1214     // Initialize the set of rewrite patterns.
1215     RewritePatternSet patterns(&getContext());
1216     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1217                  TestSignatureConversionUndo,
1218                  TestTestSignatureConversionNoConverter>(converter,
1219                                                          &getContext());
1220     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1221     mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1222         patterns, converter);
1223 
1224     if (failed(applyPartialConversion(getOperation(), target,
1225                                       std::move(patterns))))
1226       signalPassFailure();
1227   }
1228 };
1229 } // namespace
1230 
1231 //===----------------------------------------------------------------------===//
1232 // Test Target Materialization With No Uses
1233 //===----------------------------------------------------------------------===//
1234 
1235 namespace {
1236 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1237   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1238 
1239   LogicalResult
1240   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1241                   ConversionPatternRewriter &rewriter) const final {
1242     rewriter.replaceOp(op, adaptor.getOperands());
1243     return success();
1244   }
1245 };
1246 
1247 struct TestTargetMaterializationWithNoUses
1248     : public PassWrapper<TestTargetMaterializationWithNoUses,
1249                          OperationPass<ModuleOp>> {
1250   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1251       TestTargetMaterializationWithNoUses)
1252 
1253   StringRef getArgument() const final {
1254     return "test-target-materialization-with-no-uses";
1255   }
1256   StringRef getDescription() const final {
1257     return "Test a special case of target materialization in DialectConversion";
1258   }
1259 
1260   void runOnOperation() override {
1261     TypeConverter converter;
1262     converter.addConversion([](Type t) { return t; });
1263     converter.addConversion([](IntegerType intTy) -> Type {
1264       if (intTy.getWidth() == 16)
1265         return IntegerType::get(intTy.getContext(), 64);
1266       return intTy;
1267     });
1268     converter.addTargetMaterialization(
1269         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1270           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1271         });
1272 
1273     ConversionTarget target(getContext());
1274     target.addIllegalOp<TestTypeChangerOp>();
1275 
1276     RewritePatternSet patterns(&getContext());
1277     patterns.add<ForwardOperandPattern>(converter, &getContext());
1278 
1279     if (failed(applyPartialConversion(getOperation(), target,
1280                                       std::move(patterns))))
1281       signalPassFailure();
1282   }
1283 };
1284 } // namespace
1285 
1286 //===----------------------------------------------------------------------===//
1287 // Test Block Merging
1288 //===----------------------------------------------------------------------===//
1289 
1290 namespace {
1291 /// A rewriter pattern that tests that blocks can be merged.
1292 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1293   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1294 
1295   LogicalResult
1296   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1297                   ConversionPatternRewriter &rewriter) const final {
1298     Block &firstBlock = op.getBody().front();
1299     Operation *branchOp = firstBlock.getTerminator();
1300     Block *secondBlock = &*(std::next(op.getBody().begin()));
1301     auto succOperands = branchOp->getOperands();
1302     SmallVector<Value, 2> replacements(succOperands);
1303     rewriter.eraseOp(branchOp);
1304     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1305     rewriter.updateRootInPlace(op, [] {});
1306     return success();
1307   }
1308 };
1309 
1310 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1311 struct TestUndoBlocksMerge : public ConversionPattern {
1312   TestUndoBlocksMerge(MLIRContext *ctx)
1313       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1314   LogicalResult
1315   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1316                   ConversionPatternRewriter &rewriter) const final {
1317     Block &firstBlock = op->getRegion(0).front();
1318     Operation *branchOp = firstBlock.getTerminator();
1319     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1320     rewriter.setInsertionPointToStart(secondBlock);
1321     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1322     auto succOperands = branchOp->getOperands();
1323     SmallVector<Value, 2> replacements(succOperands);
1324     rewriter.eraseOp(branchOp);
1325     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1326     rewriter.updateRootInPlace(op, [] {});
1327     return success();
1328   }
1329 };
1330 
1331 /// A rewrite mechanism to inline the body of the op into its parent, when both
1332 /// ops can have a single block.
1333 struct TestMergeSingleBlockOps
1334     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1335   using OpConversionPattern<
1336       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1337 
1338   LogicalResult
1339   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1340                   ConversionPatternRewriter &rewriter) const final {
1341     SingleBlockImplicitTerminatorOp parentOp =
1342         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1343     if (!parentOp)
1344       return failure();
1345     Block &innerBlock = op.getRegion().front();
1346     TerminatorOp innerTerminator =
1347         cast<TerminatorOp>(innerBlock.getTerminator());
1348     rewriter.mergeBlockBefore(&innerBlock, op);
1349     rewriter.eraseOp(innerTerminator);
1350     rewriter.eraseOp(op);
1351     rewriter.updateRootInPlace(op, [] {});
1352     return success();
1353   }
1354 };
1355 
1356 struct TestMergeBlocksPatternDriver
1357     : public PassWrapper<TestMergeBlocksPatternDriver,
1358                          OperationPass<ModuleOp>> {
1359   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
1360 
1361   StringRef getArgument() const final { return "test-merge-blocks"; }
1362   StringRef getDescription() const final {
1363     return "Test Merging operation in ConversionPatternRewriter";
1364   }
1365   void runOnOperation() override {
1366     MLIRContext *context = &getContext();
1367     mlir::RewritePatternSet patterns(context);
1368     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1369         context);
1370     ConversionTarget target(*context);
1371     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1372                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1373     target.addIllegalOp<ILLegalOpF>();
1374 
1375     /// Expect the op to have a single block after legalization.
1376     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1377         [&](TestMergeBlocksOp op) -> bool {
1378           return llvm::hasSingleElement(op.getBody());
1379         });
1380 
1381     /// Only allow `test.br` within test.merge_blocks op.
1382     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1383       return op->getParentOfType<TestMergeBlocksOp>();
1384     });
1385 
1386     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1387     /// inlined.
1388     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1389         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1390           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1391         });
1392 
1393     DenseSet<Operation *> unlegalizedOps;
1394     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1395                                  &unlegalizedOps);
1396     for (auto *op : unlegalizedOps)
1397       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1398   }
1399 };
1400 } // namespace
1401 
1402 //===----------------------------------------------------------------------===//
1403 // Test Selective Replacement
1404 //===----------------------------------------------------------------------===//
1405 
1406 namespace {
1407 /// A rewrite mechanism to inline the body of the op into its parent, when both
1408 /// ops can have a single block.
1409 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1410   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1411 
1412   LogicalResult matchAndRewrite(TestCastOp op,
1413                                 PatternRewriter &rewriter) const final {
1414     if (op.getNumOperands() != 2)
1415       return failure();
1416     OperandRange operands = op.getOperands();
1417 
1418     // Replace non-terminator uses with the first operand.
1419     rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1420       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1421     });
1422     // Replace everything else with the second operand if the operation isn't
1423     // dead.
1424     rewriter.replaceOp(op, op.getOperand(1));
1425     return success();
1426   }
1427 };
1428 
1429 struct TestSelectiveReplacementPatternDriver
1430     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1431                          OperationPass<>> {
1432   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1433       TestSelectiveReplacementPatternDriver)
1434 
1435   StringRef getArgument() const final {
1436     return "test-pattern-selective-replacement";
1437   }
1438   StringRef getDescription() const final {
1439     return "Test selective replacement in the PatternRewriter";
1440   }
1441   void runOnOperation() override {
1442     MLIRContext *context = &getContext();
1443     mlir::RewritePatternSet patterns(context);
1444     patterns.add<TestSelectiveOpReplacementPattern>(context);
1445     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1446                                        std::move(patterns));
1447   }
1448 };
1449 } // namespace
1450 
1451 //===----------------------------------------------------------------------===//
1452 // PassRegistration
1453 //===----------------------------------------------------------------------===//
1454 
1455 namespace mlir {
1456 namespace test {
1457 void registerPatternsTestPass() {
1458   PassRegistration<TestReturnTypeDriver>();
1459 
1460   PassRegistration<TestDerivedAttributeDriver>();
1461 
1462   PassRegistration<TestPatternDriver>();
1463 
1464   PassRegistration<TestLegalizePatternDriver>([] {
1465     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
1466   });
1467 
1468   PassRegistration<TestRemappedValue>();
1469 
1470   PassRegistration<TestUnknownRootOpDriver>();
1471 
1472   PassRegistration<TestTypeConversionDriver>();
1473   PassRegistration<TestTargetMaterializationWithNoUses>();
1474 
1475   PassRegistration<TestRewriteDynamicOpDriver>();
1476 
1477   PassRegistration<TestMergeBlocksPatternDriver>();
1478   PassRegistration<TestSelectiveReplacementPatternDriver>();
1479 }
1480 } // namespace test
1481 } // namespace mlir
1482