1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/FoldUtils.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace test;
23 
24 // Native function for testing NativeCodeCall
25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
26   return choice.getValue() ? input1 : input2;
27 }
28 
29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
30   rewriter.create<OpI>(loc, input);
31 }
32 
33 static void handleNoResultOp(PatternRewriter &rewriter,
34                              OpSymbolBindingNoResult op) {
35   // Turn the no result op to a one-result op.
36   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
37                                     op.getOperand());
38 }
39 
40 static bool getFirstI32Result(Operation *op, Value &value) {
41   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
42     return false;
43   value = op->getResult(0);
44   return true;
45 }
46 
47 static Value bindNativeCodeCallResult(Value value) { return value; }
48 
49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
50                                                               Value input2) {
51   return SmallVector<Value, 2>({input2, input1});
52 }
53 
54 // Test that natives calls are only called once during rewrites.
55 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
56 // This let us check the number of times OpM_Test was called by inspecting
57 // the returned value in the MLIR output.
58 static int64_t opMIncreasingValue = 314159265;
59 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
60   int64_t i = opMIncreasingValue++;
61   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
62 }
63 
64 namespace {
65 #include "TestPatterns.inc"
66 } // namespace
67 
68 //===----------------------------------------------------------------------===//
69 // Test Reduce Pattern Interface
70 //===----------------------------------------------------------------------===//
71 
/// Adds every pattern produced by the generated `populateWithGenerated`
/// (from TestPatterns.inc) to `patterns`. This is the entry point exposed by
/// the "Test Reduce Pattern Interface" section.
void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
}
75 
76 //===----------------------------------------------------------------------===//
77 // Canonicalizer Driver.
78 //===----------------------------------------------------------------------===//
79 
80 namespace {
81 struct FoldingPattern : public RewritePattern {
82 public:
83   FoldingPattern(MLIRContext *context)
84       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
85                        /*benefit=*/1, context) {}
86 
87   LogicalResult matchAndRewrite(Operation *op,
88                                 PatternRewriter &rewriter) const override {
89     // Exercise OperationFolder API for a single-result operation that is folded
90     // upon construction. The operation being created through the folder has an
91     // in-place folder, and it should be still present in the output.
92     // Furthermore, the folder should not crash when attempting to recover the
93     // (unchanged) operation result.
94     OperationFolder folder(op->getContext());
95     Value result = folder.create<TestOpInPlaceFold>(
96         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
97         rewriter.getI32IntegerAttr(0));
98     assert(result);
99     rewriter.replaceOp(op, result);
100     return success();
101   }
102 };
103 
104 /// This pattern creates a foldable operation at the entry point of the block.
105 /// This tests the situation where the operation folder will need to replace an
106 /// operation with a previously created constant that does not initially
107 /// dominate the operation to replace.
108 struct FolderInsertBeforePreviouslyFoldedConstantPattern
109     : public OpRewritePattern<TestCastOp> {
110 public:
111   using OpRewritePattern<TestCastOp>::OpRewritePattern;
112 
113   LogicalResult matchAndRewrite(TestCastOp op,
114                                 PatternRewriter &rewriter) const override {
115     if (!op->hasAttr("test_fold_before_previously_folded_op"))
116       return failure();
117     rewriter.setInsertionPointToStart(op->getBlock());
118 
119     auto constOp = rewriter.create<arith::ConstantOp>(
120         op.getLoc(), rewriter.getBoolAttr(true));
121     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
122                                             Value(constOp));
123     return success();
124   }
125 };
126 
127 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
128   StringRef getArgument() const final { return "test-patterns"; }
129   StringRef getDescription() const final { return "Run test dialect patterns"; }
130   void runOnFunction() override {
131     mlir::RewritePatternSet patterns(&getContext());
132     populateWithGenerated(patterns);
133 
134     // Verify named pattern is generated with expected name.
135     patterns.add<FoldingPattern, TestNamedPatternRule,
136                  FolderInsertBeforePreviouslyFoldedConstantPattern>(
137         &getContext());
138 
139     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
140   }
141 };
142 } // namespace
143 
144 //===----------------------------------------------------------------------===//
145 // ReturnType Driver.
146 //===----------------------------------------------------------------------===//
147 
148 namespace {
149 // Generate ops for each instance where the type can be successfully inferred.
150 template <typename OpTy>
151 static void invokeCreateWithInferredReturnType(Operation *op) {
152   auto *context = op->getContext();
153   auto fop = op->getParentOfType<FuncOp>();
154   auto location = UnknownLoc::get(context);
155   OpBuilder b(op);
156   b.setInsertionPointAfter(op);
157 
158   // Use permutations of 2 args as operands.
159   assert(fop.getNumArguments() >= 2);
160   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
161     for (int j = 0; j < e; ++j) {
162       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
163       SmallVector<Type, 2> inferredReturnTypes;
164       if (succeeded(OpTy::inferReturnTypes(
165               context, llvm::None, values, op->getAttrDictionary(),
166               op->getRegions(), inferredReturnTypes))) {
167         OperationState state(location, OpTy::getOperationName());
168         // TODO: Expand to regions.
169         OpTy::build(b, state, values, op->getAttrs());
170         (void)b.createOperation(state);
171       }
172     }
173   }
174 }
175 
176 static void reifyReturnShape(Operation *op) {
177   OpBuilder b(op);
178 
179   // Use permutations of 2 args as operands.
180   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
181   SmallVector<Value, 2> shapes;
182   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
183       !llvm::hasSingleElement(shapes))
184     return;
185   for (auto it : llvm::enumerate(shapes)) {
186     op->emitRemark() << "value " << it.index() << ": "
187                      << it.value().getDefiningOp();
188   }
189 }
190 
191 struct TestReturnTypeDriver
192     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
193   void getDependentDialects(DialectRegistry &registry) const override {
194     registry.insert<tensor::TensorDialect>();
195   }
196   StringRef getArgument() const final { return "test-return-type"; }
197   StringRef getDescription() const final { return "Run return type functions"; }
198 
199   void runOnFunction() override {
200     if (getFunction().getName() == "testCreateFunctions") {
201       std::vector<Operation *> ops;
202       // Collect ops to avoid triggering on inserted ops.
203       for (auto &op : getFunction().getBody().front())
204         ops.push_back(&op);
205       // Generate test patterns for each, but skip terminator.
206       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
207         // Test create method of each of the Op classes below. The resultant
208         // output would be in reverse order underneath `op` from which
209         // the attributes and regions are used.
210         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
211         invokeCreateWithInferredReturnType<
212             OpWithShapedTypeInferTypeInterfaceOp>(op);
213       };
214       return;
215     }
216     if (getFunction().getName() == "testReifyFunctions") {
217       std::vector<Operation *> ops;
218       // Collect ops to avoid triggering on inserted ops.
219       for (auto &op : getFunction().getBody().front())
220         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
221           ops.push_back(&op);
222       // Generate test patterns for each, but skip terminator.
223       for (auto *op : ops)
224         reifyReturnShape(op);
225     }
226   }
227 };
228 } // namespace
229 
230 namespace {
231 struct TestDerivedAttributeDriver
232     : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
233   StringRef getArgument() const final { return "test-derived-attr"; }
234   StringRef getDescription() const final {
235     return "Run test derived attributes";
236   }
237   void runOnFunction() override;
238 };
239 } // namespace
240 
241 void TestDerivedAttributeDriver::runOnFunction() {
242   getFunction().walk([](DerivedAttributeOpInterface dOp) {
243     auto dAttr = dOp.materializeDerivedAttributes();
244     if (!dAttr)
245       return;
246     for (auto d : dAttr)
247       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
248   });
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Legalization Driver.
253 //===----------------------------------------------------------------------===//
254 
255 namespace {
256 //===----------------------------------------------------------------------===//
257 // Region-Block Rewrite Testing
258 
259 /// This pattern is a simple pattern that inlines the first region of a given
260 /// operation into the parent region.
261 struct TestRegionRewriteBlockMovement : public ConversionPattern {
262   TestRegionRewriteBlockMovement(MLIRContext *ctx)
263       : ConversionPattern("test.region", 1, ctx) {}
264 
265   LogicalResult
266   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
267                   ConversionPatternRewriter &rewriter) const final {
268     // Inline this region into the parent region.
269     auto &parentRegion = *op->getParentRegion();
270     auto &opRegion = op->getRegion(0);
271     if (op->getAttr("legalizer.should_clone"))
272       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
273     else
274       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
275 
276     if (op->getAttr("legalizer.erase_old_blocks")) {
277       while (!opRegion.empty())
278         rewriter.eraseBlock(&opRegion.front());
279     }
280 
281     // Drop this operation.
282     rewriter.eraseOp(op);
283     return success();
284   }
285 };
286 /// This pattern is a simple pattern that generates a region containing an
287 /// illegal operation.
288 struct TestRegionRewriteUndo : public RewritePattern {
289   TestRegionRewriteUndo(MLIRContext *ctx)
290       : RewritePattern("test.region_builder", 1, ctx) {}
291 
292   LogicalResult matchAndRewrite(Operation *op,
293                                 PatternRewriter &rewriter) const final {
294     // Create the region operation with an entry block containing arguments.
295     OperationState newRegion(op->getLoc(), "test.region");
296     newRegion.addRegion();
297     auto *regionOp = rewriter.createOperation(newRegion);
298     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
299     entryBlock->addArgument(rewriter.getIntegerType(64));
300 
301     // Add an explicitly illegal operation to ensure the conversion fails.
302     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
303     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
304 
305     // Drop this operation.
306     rewriter.eraseOp(op);
307     return success();
308   }
309 };
310 /// A simple pattern that creates a block at the end of the parent region of the
311 /// matched operation.
312 struct TestCreateBlock : public RewritePattern {
313   TestCreateBlock(MLIRContext *ctx)
314       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
315 
316   LogicalResult matchAndRewrite(Operation *op,
317                                 PatternRewriter &rewriter) const final {
318     Region &region = *op->getParentRegion();
319     Type i32Type = rewriter.getIntegerType(32);
320     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
321     rewriter.create<TerminatorOp>(op->getLoc());
322     rewriter.replaceOp(op, {});
323     return success();
324   }
325 };
326 
327 /// A simple pattern that creates a block containing an invalid operation in
328 /// order to trigger the block creation undo mechanism.
329 struct TestCreateIllegalBlock : public RewritePattern {
330   TestCreateIllegalBlock(MLIRContext *ctx)
331       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
332 
333   LogicalResult matchAndRewrite(Operation *op,
334                                 PatternRewriter &rewriter) const final {
335     Region &region = *op->getParentRegion();
336     Type i32Type = rewriter.getIntegerType(32);
337     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
338     // Create an illegal op to ensure the conversion fails.
339     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
340     rewriter.create<TerminatorOp>(op->getLoc());
341     rewriter.replaceOp(op, {});
342     return success();
343   }
344 };
345 
346 /// A simple pattern that tests the undo mechanism when replacing the uses of a
347 /// block argument.
348 struct TestUndoBlockArgReplace : public ConversionPattern {
349   TestUndoBlockArgReplace(MLIRContext *ctx)
350       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
351 
352   LogicalResult
353   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
354                   ConversionPatternRewriter &rewriter) const final {
355     auto illegalOp =
356         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
357     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
358                                         illegalOp);
359     rewriter.updateRootInPlace(op, [] {});
360     return success();
361   }
362 };
363 
364 /// A rewrite pattern that tests the undo mechanism when erasing a block.
365 struct TestUndoBlockErase : public ConversionPattern {
366   TestUndoBlockErase(MLIRContext *ctx)
367       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
368 
369   LogicalResult
370   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
371                   ConversionPatternRewriter &rewriter) const final {
372     Block *secondBlock = &*std::next(op->getRegion(0).begin());
373     rewriter.setInsertionPointToStart(secondBlock);
374     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
375     rewriter.eraseBlock(secondBlock);
376     rewriter.updateRootInPlace(op, [] {});
377     return success();
378   }
379 };
380 
381 //===----------------------------------------------------------------------===//
382 // Type-Conversion Rewrite Testing
383 
384 /// This patterns erases a region operation that has had a type conversion.
385 struct TestDropOpSignatureConversion : public ConversionPattern {
386   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
387       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
388   LogicalResult
389   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
390                   ConversionPatternRewriter &rewriter) const override {
391     Region &region = op->getRegion(0);
392     Block *entry = &region.front();
393 
394     // Convert the original entry arguments.
395     TypeConverter &converter = *getTypeConverter();
396     TypeConverter::SignatureConversion result(entry->getNumArguments());
397     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
398                                               result)) ||
399         failed(rewriter.convertRegionTypes(&region, converter, &result)))
400       return failure();
401 
402     // Convert the region signature and just drop the operation.
403     rewriter.eraseOp(op);
404     return success();
405   }
406 };
407 /// This pattern simply updates the operands of the given operation.
408 struct TestPassthroughInvalidOp : public ConversionPattern {
409   TestPassthroughInvalidOp(MLIRContext *ctx)
410       : ConversionPattern("test.invalid", 1, ctx) {}
411   LogicalResult
412   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
413                   ConversionPatternRewriter &rewriter) const final {
414     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
415                                              llvm::None);
416     return success();
417   }
418 };
419 /// This pattern handles the case of a split return value.
420 struct TestSplitReturnType : public ConversionPattern {
421   TestSplitReturnType(MLIRContext *ctx)
422       : ConversionPattern("test.return", 1, ctx) {}
423   LogicalResult
424   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
425                   ConversionPatternRewriter &rewriter) const final {
426     // Check for a return of F32.
427     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
428       return failure();
429 
430     // Check if the first operation is a cast operation, if it is we use the
431     // results directly.
432     auto *defOp = operands[0].getDefiningOp();
433     if (auto packerOp =
434             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
435       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
436       return success();
437     }
438 
439     // Otherwise, fail to match.
440     return failure();
441   }
442 };
443 
444 //===----------------------------------------------------------------------===//
445 // Multi-Level Type-Conversion Rewrite Testing
446 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
447   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
448       : ConversionPattern("test.type_producer", 1, ctx) {}
449   LogicalResult
450   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
451                   ConversionPatternRewriter &rewriter) const final {
452     // If the type is I32, change the type to F32.
453     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
454       return failure();
455     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
456     return success();
457   }
458 };
459 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
460   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
461       : ConversionPattern("test.type_producer", 1, ctx) {}
462   LogicalResult
463   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
464                   ConversionPatternRewriter &rewriter) const final {
465     // If the type is F32, change the type to F64.
466     if (!Type(*op->result_type_begin()).isF32())
467       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
468     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
469     return success();
470   }
471 };
472 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
473   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
474       : ConversionPattern("test.type_producer", 10, ctx) {}
475   LogicalResult
476   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
477                   ConversionPatternRewriter &rewriter) const final {
478     // Always convert to B16, even though it is not a legal type. This tests
479     // that values are unmapped correctly.
480     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
481     return success();
482   }
483 };
484 struct TestUpdateConsumerType : public ConversionPattern {
485   TestUpdateConsumerType(MLIRContext *ctx)
486       : ConversionPattern("test.type_consumer", 1, ctx) {}
487   LogicalResult
488   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
489                   ConversionPatternRewriter &rewriter) const final {
490     // Verify that the incoming operand has been successfully remapped to F64.
491     if (!operands[0].getType().isF64())
492       return failure();
493     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
494     return success();
495   }
496 };
497 
498 //===----------------------------------------------------------------------===//
499 // Non-Root Replacement Rewrite Testing
500 /// This pattern generates an invalid operation, but replaces it before the
501 /// pattern is finished. This checks that we don't need to legalize the
502 /// temporary op.
503 struct TestNonRootReplacement : public RewritePattern {
504   TestNonRootReplacement(MLIRContext *ctx)
505       : RewritePattern("test.replace_non_root", 1, ctx) {}
506 
507   LogicalResult matchAndRewrite(Operation *op,
508                                 PatternRewriter &rewriter) const final {
509     auto resultType = *op->result_type_begin();
510     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
511     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
512 
513     rewriter.replaceOp(illegalOp, {legalOp});
514     rewriter.replaceOp(op, {illegalOp});
515     return success();
516   }
517 };
518 
519 //===----------------------------------------------------------------------===//
520 // Recursive Rewrite Testing
521 /// This pattern is applied to the same operation multiple times, but has a
522 /// bounded recursion.
523 struct TestBoundedRecursiveRewrite
524     : public OpRewritePattern<TestRecursiveRewriteOp> {
525   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
526 
527   void initialize() {
528     // The conversion target handles bounding the recursion of this pattern.
529     setHasBoundedRewriteRecursion();
530   }
531 
532   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
533                                 PatternRewriter &rewriter) const final {
534     // Decrement the depth of the op in-place.
535     rewriter.updateRootInPlace(op, [&] {
536       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
537     });
538     return success();
539   }
540 };
541 
542 struct TestNestedOpCreationUndoRewrite
543     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
544   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
545 
546   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
547                                 PatternRewriter &rewriter) const final {
548     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
549     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
550     return success();
551   };
552 };
553 
554 // This pattern matches `test.blackhole` and delete this op and its producer.
555 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
556   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
557 
558   LogicalResult matchAndRewrite(BlackHoleOp op,
559                                 PatternRewriter &rewriter) const final {
560     Operation *producer = op.getOperand().getDefiningOp();
561     // Always erase the user before the producer, the framework should handle
562     // this correctly.
563     rewriter.eraseOp(op);
564     rewriter.eraseOp(producer);
565     return success();
566   };
567 };
568 
569 // This pattern replaces explicitly illegal op with explicitly legal op,
570 // but in addition creates unregistered operation.
571 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
572   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
573 
574   LogicalResult matchAndRewrite(ILLegalOpG op,
575                                 PatternRewriter &rewriter) const final {
576     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
577     Value val = rewriter.create<ConstantOp>(op->getLoc(), attr);
578     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
579     return success();
580   };
581 };
582 } // namespace
583 
584 namespace {
585 struct TestTypeConverter : public TypeConverter {
586   using TypeConverter::TypeConverter;
587   TestTypeConverter() {
588     addConversion(convertType);
589     addArgumentMaterialization(materializeCast);
590     addSourceMaterialization(materializeCast);
591   }
592 
593   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
594     // Drop I16 types.
595     if (t.isSignlessInteger(16))
596       return success();
597 
598     // Convert I64 to F64.
599     if (t.isSignlessInteger(64)) {
600       results.push_back(FloatType::getF64(t.getContext()));
601       return success();
602     }
603 
604     // Convert I42 to I43.
605     if (t.isInteger(42)) {
606       results.push_back(IntegerType::get(t.getContext(), 43));
607       return success();
608     }
609 
610     // Split F32 into F16,F16.
611     if (t.isF32()) {
612       results.assign(2, FloatType::getF16(t.getContext()));
613       return success();
614     }
615 
616     // Otherwise, convert the type directly.
617     results.push_back(t);
618     return success();
619   }
620 
621   /// Hook for materializing a conversion. This is necessary because we generate
622   /// 1->N type mappings.
623   static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
624                                          ValueRange inputs, Location loc) {
625     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
626   }
627 };
628 
629 struct TestLegalizePatternDriver
630     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
631   StringRef getArgument() const final { return "test-legalize-patterns"; }
632   StringRef getDescription() const final {
633     return "Run test dialect legalization patterns";
634   }
635   /// The mode of conversion to use with the driver.
636   enum class ConversionMode { Analysis, Full, Partial };
637 
638   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
639 
640   void getDependentDialects(DialectRegistry &registry) const override {
641     registry.insert<StandardOpsDialect>();
642   }
643 
644   void runOnOperation() override {
645     TestTypeConverter converter;
646     mlir::RewritePatternSet patterns(&getContext());
647     populateWithGenerated(patterns);
648     patterns
649         .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
650              TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
651              TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
652              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
653              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
654              TestNonRootReplacement, TestBoundedRecursiveRewrite,
655              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
656              TestCreateUnregisteredOp>(&getContext());
657     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
658     mlir::populateFuncOpTypeConversionPattern(patterns, converter);
659     mlir::populateCallOpTypeConversionPattern(patterns, converter);
660 
661     // Define the conversion target used for the test.
662     ConversionTarget target(getContext());
663     target.addLegalOp<ModuleOp>();
664     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
665                       TerminatorOp>();
666     target
667         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
668     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
669       // Don't allow F32 operands.
670       return llvm::none_of(op.getOperandTypes(),
671                            [](Type type) { return type.isF32(); });
672     });
673     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
674       return converter.isSignatureLegal(op.getType()) &&
675              converter.isLegal(&op.getBody());
676     });
677     target.addDynamicallyLegalOp<CallOp>(
678         [&](CallOp op) { return converter.isLegal(op); });
679 
680     // TestCreateUnregisteredOp creates `arith.constant` operation,
681     // which was not added to target intentionally to test
682     // correct error code from conversion driver.
683     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
684 
685     // Expect the type_producer/type_consumer operations to only operate on f64.
686     target.addDynamicallyLegalOp<TestTypeProducerOp>(
687         [](TestTypeProducerOp op) { return op.getType().isF64(); });
688     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
689       return op.getOperand().getType().isF64();
690     });
691 
692     // Check support for marking certain operations as recursively legal.
693     target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
694       return static_cast<bool>(
695           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
696     });
697 
698     // Mark the bound recursion operation as dynamically legal.
699     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
700         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
701 
702     // Handle a partial conversion.
703     if (mode == ConversionMode::Partial) {
704       DenseSet<Operation *> unlegalizedOps;
705       if (failed(applyPartialConversion(
706               getOperation(), target, std::move(patterns), &unlegalizedOps))) {
707         getOperation()->emitRemark() << "applyPartialConversion failed";
708       }
709       // Emit remarks for each legalizable operation.
710       for (auto *op : unlegalizedOps)
711         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
712       return;
713     }
714 
715     // Handle a full conversion.
716     if (mode == ConversionMode::Full) {
717       // Check support for marking unknown operations as dynamically legal.
718       target.markUnknownOpDynamicallyLegal([](Operation *op) {
719         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
720       });
721 
722       if (failed(applyFullConversion(getOperation(), target,
723                                      std::move(patterns)))) {
724         getOperation()->emitRemark() << "applyFullConversion failed";
725       }
726       return;
727     }
728 
729     // Otherwise, handle an analysis conversion.
730     assert(mode == ConversionMode::Analysis);
731 
732     // Analyze the convertible operations.
733     DenseSet<Operation *> legalizedOps;
734     if (failed(applyAnalysisConversion(getOperation(), target,
735                                        std::move(patterns), legalizedOps)))
736       return signalPassFailure();
737 
738     // Emit remarks for each legalizable operation.
739     for (auto *op : legalizedOps)
740       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
741   }
742 
743   /// The mode of conversion to use.
744   ConversionMode mode;
745 };
746 } // namespace
747 
/// Command-line option (`-test-legalize-mode`) selecting which conversion
/// entry point TestLegalizePatternDriver uses. It is read when the pass is
/// constructed by the allocator in registerPatternsTestPass. Defaults to a
/// partial conversion.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        // The order of the values below is the order shown in `--help`.
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
760 
761 //===----------------------------------------------------------------------===//
762 // ConversionPatternRewriter::getRemappedValue testing. This method is used
763 // to get the remapped value of an original value that was replaced using
764 // ConversionPatternRewriter.
765 namespace {
766 struct TestRemapValueTypeConverter : public TypeConverter {
767   using TypeConverter::TypeConverter;
768 
769   TestRemapValueTypeConverter() {
770     addConversion(
771         [](Float32Type type) { return Float64Type::get(type.getContext()); });
772     addConversion([](Type type) { return type; });
773   }
774 };
775 
776 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
777 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
778 /// operand twice.
779 ///
780 /// Example:
781 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
782 /// is replaced with:
783 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
784 struct OneVResOneVOperandOp1Converter
785     : public OpConversionPattern<OneVResOneVOperandOp1> {
786   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
787 
788   LogicalResult
789   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
790                   ConversionPatternRewriter &rewriter) const override {
791     auto origOps = op.getOperands();
792     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
793            "One operand expected");
794     Value origOp = *origOps.begin();
795     SmallVector<Value, 2> remappedOperands;
796     // Replicate the remapped original operand twice. Note that we don't used
797     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
798     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
799     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
800 
801     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
802                                                        remappedOperands);
803     return success();
804   }
805 };
806 
/// A conversion pattern that inlines the first body block of a
/// TestRemappedValueRegionOp into the block containing the op. It then
/// replaces the op's results with the remapped operands of that body block's
/// terminator, which exercises `ConversionPatternRewriter::getRemappedValues`.
struct TestRemapValueInRegion
    : public OpConversionPattern<TestRemappedValueRegionOp> {
  using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Only the first block of the body is inlined (presumably the body has a
    // single block; verify against the op definition).
    Block &block = op.getBody().front();
    Operation *terminator = block.getTerminator();

    // Merge the block into the parent region. Split the parent block at `op`,
    // append the body block to the parent block, then re-append the tail
    // that was split off.
    Block *parentBlock = op->getBlock();
    Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
    rewriter.mergeBlocks(&block, parentBlock, ValueRange());
    rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());

    // Replace the results of this operation with the remapped terminator
    // values.
    // NOTE(review): on failure the merges above have already been performed;
    // this relies on the conversion driver rolling back the pattern's changes.
    SmallVector<Value> terminatorOperands;
    if (failed(rewriter.getRemappedValues(terminator->getOperands(),
                                          terminatorOperands)))
      return failure();

    // The terminator is dropped once its (remapped) operands are captured.
    rewriter.eraseOp(terminator);
    rewriter.replaceOp(op, terminatorOperands);
    return success();
  }
};
836 
837 struct TestRemappedValue
838     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
839   StringRef getArgument() const final { return "test-remapped-value"; }
840   StringRef getDescription() const final {
841     return "Test public remapped value mechanism in ConversionPatternRewriter";
842   }
843   void runOnFunction() override {
844     TestRemapValueTypeConverter typeConverter;
845 
846     mlir::RewritePatternSet patterns(&getContext());
847     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
848     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
849         &getContext());
850     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
851 
852     mlir::ConversionTarget target(getContext());
853     target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
854 
855     // Expect the type_producer/type_consumer operations to only operate on f64.
856     target.addDynamicallyLegalOp<TestTypeProducerOp>(
857         [](TestTypeProducerOp op) { return op.getType().isF64(); });
858     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
859       return op.getOperand().getType().isF64();
860     });
861 
862     // We make OneVResOneVOperandOp1 legal only when it has more that one
863     // operand. This will trigger the conversion that will replace one-operand
864     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
865     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
866         [](Operation *op) { return op->getNumOperands() > 1; });
867 
868     if (failed(mlir::applyFullConversion(getFunction(), target,
869                                          std::move(patterns)))) {
870       signalPassFailure();
871     }
872   }
873 };
874 } // namespace
875 
876 //===----------------------------------------------------------------------===//
877 // Test patterns without a specific root operation kind
878 //===----------------------------------------------------------------------===//
879 
880 namespace {
881 /// This pattern matches and removes any operation in the test dialect.
882 struct RemoveTestDialectOps : public RewritePattern {
883   RemoveTestDialectOps(MLIRContext *context)
884       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
885 
886   LogicalResult matchAndRewrite(Operation *op,
887                                 PatternRewriter &rewriter) const override {
888     if (!isa<TestDialect>(op->getDialect()))
889       return failure();
890     rewriter.eraseOp(op);
891     return success();
892   }
893 };
894 
895 struct TestUnknownRootOpDriver
896     : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
897   StringRef getArgument() const final {
898     return "test-legalize-unknown-root-patterns";
899   }
900   StringRef getDescription() const final {
901     return "Test public remapped value mechanism in ConversionPatternRewriter";
902   }
903   void runOnFunction() override {
904     mlir::RewritePatternSet patterns(&getContext());
905     patterns.add<RemoveTestDialectOps>(&getContext());
906 
907     mlir::ConversionTarget target(getContext());
908     target.addIllegalDialect<TestDialect>();
909     if (failed(
910             applyPartialConversion(getFunction(), target, std::move(patterns))))
911       signalPassFailure();
912   }
913 };
914 } // namespace
915 
916 //===----------------------------------------------------------------------===//
917 // Test type conversions
918 //===----------------------------------------------------------------------===//
919 
920 namespace {
921 struct TestTypeConversionProducer
922     : public OpConversionPattern<TestTypeProducerOp> {
923   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
924   LogicalResult
925   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
926                   ConversionPatternRewriter &rewriter) const final {
927     Type resultType = op.getType();
928     Type convertedType = getTypeConverter()
929                              ? getTypeConverter()->convertType(resultType)
930                              : resultType;
931     if (resultType.isa<FloatType>())
932       resultType = rewriter.getF64Type();
933     else if (resultType.isInteger(16))
934       resultType = rewriter.getIntegerType(64);
935     else if (resultType.isa<test::TestRecursiveType>() &&
936              convertedType != resultType)
937       resultType = convertedType;
938     else
939       return failure();
940 
941     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
942     return success();
943   }
944 };
945 
946 /// Call signature conversion and then fail the rewrite to trigger the undo
947 /// mechanism.
948 struct TestSignatureConversionUndo
949     : public OpConversionPattern<TestSignatureConversionUndoOp> {
950   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
951 
952   LogicalResult
953   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
954                   ConversionPatternRewriter &rewriter) const final {
955     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
956     return failure();
957   }
958 };
959 
960 /// Call signature conversion without providing a type converter to handle
961 /// materializations.
962 struct TestTestSignatureConversionNoConverter
963     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
964   TestTestSignatureConversionNoConverter(TypeConverter &converter,
965                                          MLIRContext *context)
966       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
967         converter(converter) {}
968 
969   LogicalResult
970   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
971                   ConversionPatternRewriter &rewriter) const final {
972     Region &region = op->getRegion(0);
973     Block *entry = &region.front();
974 
975     // Convert the original entry arguments.
976     TypeConverter::SignatureConversion result(entry->getNumArguments());
977     if (failed(
978             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
979       return failure();
980     rewriter.updateRootInPlace(
981         op, [&] { rewriter.applySignatureConversion(&region, result); });
982     return success();
983   }
984 
985   TypeConverter &converter;
986 };
987 
988 /// Just forward the operands to the root op. This is essentially a no-op
989 /// pattern that is used to trigger target materialization.
990 struct TestTypeConsumerForward
991     : public OpConversionPattern<TestTypeConsumerOp> {
992   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
993 
994   LogicalResult
995   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
996                   ConversionPatternRewriter &rewriter) const final {
997     rewriter.updateRootInPlace(op,
998                                [&] { op->setOperands(adaptor.getOperands()); });
999     return success();
1000   }
1001 };
1002 
1003 struct TestTypeConversionAnotherProducer
1004     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1005   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1006 
1007   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1008                                 PatternRewriter &rewriter) const final {
1009     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1010     return success();
1011   }
1012 };
1013 
1014 struct TestTypeConversionDriver
1015     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
1016   void getDependentDialects(DialectRegistry &registry) const override {
1017     registry.insert<TestDialect>();
1018   }
1019   StringRef getArgument() const final {
1020     return "test-legalize-type-conversion";
1021   }
1022   StringRef getDescription() const final {
1023     return "Test various type conversion functionalities in DialectConversion";
1024   }
1025 
1026   void runOnOperation() override {
1027     // Initialize the type converter.
1028     TypeConverter converter;
1029 
1030     /// Add the legal set of type conversions.
1031     converter.addConversion([](Type type) -> Type {
1032       // Treat F64 as legal.
1033       if (type.isF64())
1034         return type;
1035       // Allow converting BF16/F16/F32 to F64.
1036       if (type.isBF16() || type.isF16() || type.isF32())
1037         return FloatType::getF64(type.getContext());
1038       // Otherwise, the type is illegal.
1039       return nullptr;
1040     });
1041     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1042       // Drop all integer types.
1043       return success();
1044     });
1045     converter.addConversion(
1046         // Convert a recursive self-referring type into a non-self-referring
1047         // type named "outer_converted_type" that contains a SimpleAType.
1048         [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
1049             ArrayRef<Type> callStack) -> Optional<LogicalResult> {
1050           // If the type is already converted, return it to indicate that it is
1051           // legal.
1052           if (type.getName() == "outer_converted_type") {
1053             results.push_back(type);
1054             return success();
1055           }
1056 
1057           // If the type is on the call stack more than once (it is there at
1058           // least once because of the _current_ call, which is always the last
1059           // element on the stack), we've hit the recursive case. Just return
1060           // SimpleAType here to create a non-recursive type as a result.
1061           if (llvm::is_contained(callStack.drop_back(), type)) {
1062             results.push_back(test::SimpleAType::get(type.getContext()));
1063             return success();
1064           }
1065 
1066           // Convert the body recursively.
1067           auto result = test::TestRecursiveType::get(type.getContext(),
1068                                                      "outer_converted_type");
1069           if (failed(result.setBody(converter.convertType(type.getBody()))))
1070             return failure();
1071           results.push_back(result);
1072           return success();
1073         });
1074 
1075     /// Add the legal set of type materializations.
1076     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1077                                           ValueRange inputs,
1078                                           Location loc) -> Value {
1079       // Allow casting from F64 back to F32.
1080       if (!resultType.isF16() && inputs.size() == 1 &&
1081           inputs[0].getType().isF64())
1082         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1083       // Allow producing an i32 or i64 from nothing.
1084       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1085           inputs.empty())
1086         return builder.create<TestTypeProducerOp>(loc, resultType);
1087       // Allow producing an i64 from an integer.
1088       if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
1089           inputs[0].getType().isa<IntegerType>())
1090         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1091       // Otherwise, fail.
1092       return nullptr;
1093     });
1094 
1095     // Initialize the conversion target.
1096     mlir::ConversionTarget target(getContext());
1097     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1098       auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
1099       return op.getType().isF64() || op.getType().isInteger(64) ||
1100              (recursiveType &&
1101               recursiveType.getName() == "outer_converted_type");
1102     });
1103     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
1104       return converter.isSignatureLegal(op.getType()) &&
1105              converter.isLegal(&op.getBody());
1106     });
1107     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1108       // Allow casts from F64 to F32.
1109       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1110     });
1111     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1112         [&](TestSignatureConversionNoConverterOp op) {
1113           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1114         });
1115 
1116     // Initialize the set of rewrite patterns.
1117     RewritePatternSet patterns(&getContext());
1118     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1119                  TestSignatureConversionUndo,
1120                  TestTestSignatureConversionNoConverter>(converter,
1121                                                          &getContext());
1122     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1123     mlir::populateFuncOpTypeConversionPattern(patterns, converter);
1124 
1125     if (failed(applyPartialConversion(getOperation(), target,
1126                                       std::move(patterns))))
1127       signalPassFailure();
1128   }
1129 };
1130 } // namespace
1131 
1132 //===----------------------------------------------------------------------===//
1133 // Test Block Merging
1134 //===----------------------------------------------------------------------===//
1135 
1136 namespace {
1137 /// A rewriter pattern that tests that blocks can be merged.
1138 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1139   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1140 
1141   LogicalResult
1142   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1143                   ConversionPatternRewriter &rewriter) const final {
1144     Block &firstBlock = op.getBody().front();
1145     Operation *branchOp = firstBlock.getTerminator();
1146     Block *secondBlock = &*(std::next(op.getBody().begin()));
1147     auto succOperands = branchOp->getOperands();
1148     SmallVector<Value, 2> replacements(succOperands);
1149     rewriter.eraseOp(branchOp);
1150     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1151     rewriter.updateRootInPlace(op, [] {});
1152     return success();
1153   }
1154 };
1155 
1156 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1157 struct TestUndoBlocksMerge : public ConversionPattern {
1158   TestUndoBlocksMerge(MLIRContext *ctx)
1159       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1160   LogicalResult
1161   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1162                   ConversionPatternRewriter &rewriter) const final {
1163     Block &firstBlock = op->getRegion(0).front();
1164     Operation *branchOp = firstBlock.getTerminator();
1165     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1166     rewriter.setInsertionPointToStart(secondBlock);
1167     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1168     auto succOperands = branchOp->getOperands();
1169     SmallVector<Value, 2> replacements(succOperands);
1170     rewriter.eraseOp(branchOp);
1171     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1172     rewriter.updateRootInPlace(op, [] {});
1173     return success();
1174   }
1175 };
1176 
/// A rewrite mechanism to inline the body of the op into its parent, when both
/// ops can have a single block.
///
/// Only SingleBlockImplicitTerminatorOps nested inside another such op are
/// rewritten: the body block is merged in front of `op`, the body's
/// TerminatorOp is erased, and `op` itself is erased.
struct TestMergeSingleBlockOps
    : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
  using OpConversionPattern<
      SingleBlockImplicitTerminatorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Ops without an enclosing SingleBlockImplicitTerminatorOp are left alone.
    SingleBlockImplicitTerminatorOp parentOp =
        op->getParentOfType<SingleBlockImplicitTerminatorOp>();
    if (!parentOp)
      return failure();
    Block &innerBlock = op.getRegion().front();
    // cast<> asserts that the body block really ends in a TerminatorOp.
    TerminatorOp innerTerminator =
        cast<TerminatorOp>(innerBlock.getTerminator());
    rewriter.mergeBlockBefore(&innerBlock, op);
    rewriter.eraseOp(innerTerminator);
    rewriter.eraseOp(op);
    // NOTE(review): this notifies an in-place update of `op` after `op` has
    // already been scheduled for erasure, which looks redundant. Confirm
    // whether it (or an update of `parentOp`) is actually intended.
    rewriter.updateRootInPlace(op, [] {});
    return success();
  }
};
1201 
1202 struct TestMergeBlocksPatternDriver
1203     : public PassWrapper<TestMergeBlocksPatternDriver,
1204                          OperationPass<ModuleOp>> {
1205   StringRef getArgument() const final { return "test-merge-blocks"; }
1206   StringRef getDescription() const final {
1207     return "Test Merging operation in ConversionPatternRewriter";
1208   }
1209   void runOnOperation() override {
1210     MLIRContext *context = &getContext();
1211     mlir::RewritePatternSet patterns(context);
1212     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1213         context);
1214     ConversionTarget target(*context);
1215     target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1216                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1217     target.addIllegalOp<ILLegalOpF>();
1218 
1219     /// Expect the op to have a single block after legalization.
1220     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1221         [&](TestMergeBlocksOp op) -> bool {
1222           return llvm::hasSingleElement(op.getBody());
1223         });
1224 
1225     /// Only allow `test.br` within test.merge_blocks op.
1226     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1227       return op->getParentOfType<TestMergeBlocksOp>();
1228     });
1229 
1230     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1231     /// inlined.
1232     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1233         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1234           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1235         });
1236 
1237     DenseSet<Operation *> unlegalizedOps;
1238     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1239                                  &unlegalizedOps);
1240     for (auto *op : unlegalizedOps)
1241       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1242   }
1243 };
1244 } // namespace
1245 
1246 //===----------------------------------------------------------------------===//
1247 // Test Selective Replacement
1248 //===----------------------------------------------------------------------===//
1249 
1250 namespace {
1251 /// A rewrite mechanism to inline the body of the op into its parent, when both
1252 /// ops can have a single block.
1253 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1254   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1255 
1256   LogicalResult matchAndRewrite(TestCastOp op,
1257                                 PatternRewriter &rewriter) const final {
1258     if (op.getNumOperands() != 2)
1259       return failure();
1260     OperandRange operands = op.getOperands();
1261 
1262     // Replace non-terminator uses with the first operand.
1263     rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1264       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1265     });
1266     // Replace everything else with the second operand if the operation isn't
1267     // dead.
1268     rewriter.replaceOp(op, op.getOperand(1));
1269     return success();
1270   }
1271 };
1272 
1273 struct TestSelectiveReplacementPatternDriver
1274     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1275                          OperationPass<>> {
1276   StringRef getArgument() const final {
1277     return "test-pattern-selective-replacement";
1278   }
1279   StringRef getDescription() const final {
1280     return "Test selective replacement in the PatternRewriter";
1281   }
1282   void runOnOperation() override {
1283     MLIRContext *context = &getContext();
1284     mlir::RewritePatternSet patterns(context);
1285     patterns.add<TestSelectiveOpReplacementPattern>(context);
1286     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1287                                        std::move(patterns));
1288   }
1289 };
1290 } // namespace
1291 
1292 //===----------------------------------------------------------------------===//
1293 // PassRegistration
1294 //===----------------------------------------------------------------------===//
1295 
1296 namespace mlir {
1297 namespace test {
/// Registers the pattern/conversion test passes listed below with the global
/// pass registry.
void registerPatternsTestPass() {
  PassRegistration<TestReturnTypeDriver>();

  PassRegistration<TestDerivedAttributeDriver>();

  PassRegistration<TestPatternDriver>();

  // Uses a custom allocator so that the value of the `-test-legalize-mode`
  // option is forwarded to the pass constructor.
  PassRegistration<TestLegalizePatternDriver>([] {
    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
  });

  PassRegistration<TestRemappedValue>();

  PassRegistration<TestUnknownRootOpDriver>();

  PassRegistration<TestTypeConversionDriver>();

  PassRegistration<TestMergeBlocksPatternDriver>();
  PassRegistration<TestSelectiveReplacementPatternDriver>();
}
1318 } // namespace test
1319 } // namespace mlir
1320