1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/FoldUtils.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace test;
23 
24 // Native function for testing NativeCodeCall
25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
26   return choice.getValue() ? input1 : input2;
27 }
28 
29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
30   rewriter.create<OpI>(loc, input);
31 }
32 
33 static void handleNoResultOp(PatternRewriter &rewriter,
34                              OpSymbolBindingNoResult op) {
35   // Turn the no result op to a one-result op.
36   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
37                                     op.getOperand());
38 }
39 
40 static bool getFirstI32Result(Operation *op, Value &value) {
41   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
42     return false;
43   value = op->getResult(0);
44   return true;
45 }
46 
47 static Value bindNativeCodeCallResult(Value value) { return value; }
48 
49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
50                                                               Value input2) {
51   return SmallVector<Value, 2>({input2, input1});
52 }
53 
54 // Test that natives calls are only called once during rewrites.
55 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
56 // This let us check the number of times OpM_Test was called by inspecting
57 // the returned value in the MLIR output.
58 static int64_t opMIncreasingValue = 314159265;
59 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
60   int64_t i = opMIncreasingValue++;
61   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
62 }
63 
namespace {
// Generated patterns for the test dialect (presumably mlir-tblgen output from
// the dialect's DRR definitions). The static helpers above are presumably
// referenced by the generated code as native calls, which would explain why
// they are defined before this include. `populateWithGenerated`, used by the
// drivers below, is presumably declared by this include as well.
#include "TestPatterns.inc"
} // namespace
67 
68 //===----------------------------------------------------------------------===//
69 // Test Reduce Pattern Interface
70 //===----------------------------------------------------------------------===//
71 
/// Entry point for the test reduce-pattern interface: adds every generated test
/// pattern (see `populateWithGenerated`) to `patterns`.
void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
}
75 
76 //===----------------------------------------------------------------------===//
77 // Canonicalizer Driver.
78 //===----------------------------------------------------------------------===//
79 
80 namespace {
81 struct FoldingPattern : public RewritePattern {
82 public:
83   FoldingPattern(MLIRContext *context)
84       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
85                        /*benefit=*/1, context) {}
86 
87   LogicalResult matchAndRewrite(Operation *op,
88                                 PatternRewriter &rewriter) const override {
89     // Exercise OperationFolder API for a single-result operation that is folded
90     // upon construction. The operation being created through the folder has an
91     // in-place folder, and it should be still present in the output.
92     // Furthermore, the folder should not crash when attempting to recover the
93     // (unchanged) operation result.
94     OperationFolder folder(op->getContext());
95     Value result = folder.create<TestOpInPlaceFold>(
96         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
97         rewriter.getI32IntegerAttr(0));
98     assert(result);
99     rewriter.replaceOp(op, result);
100     return success();
101   }
102 };
103 
/// This pattern creates a foldable operation at the entry point of the block.
/// This tests the situation where the operation folder will need to replace an
/// operation with a previously created constant that does not initially
/// dominate the operation to replace.
///
/// Only TestCastOps carrying the `test_fold_before_previously_folded_op`
/// attribute are matched. The matched op is replaced with a new TestCastOp to
/// i32 whose operand is an `arith.constant true` placed at the start of the
/// matched op's block.
struct FolderInsertBeforePreviouslyFoldedConstantPattern
    : public OpRewritePattern<TestCastOp> {
public:
  using OpRewritePattern<TestCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TestCastOp op,
                                PatternRewriter &rewriter) const override {
    if (!op->hasAttr("test_fold_before_previously_folded_op"))
      return failure();
    // Both ops created below go to the start of the enclosing block, not next
    // to `op`.
    rewriter.setInsertionPointToStart(op->getBlock());

    auto constOp = rewriter.create<arith::ConstantOp>(
        op.getLoc(), rewriter.getBoolAttr(true));
    rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
                                            Value(constOp));
    return success();
  }
};
126 
127 struct TestPatternDriver
128     : public PassWrapper<TestPatternDriver, OperationPass<FuncOp>> {
129   StringRef getArgument() const final { return "test-patterns"; }
130   StringRef getDescription() const final { return "Run test dialect patterns"; }
131   void runOnOperation() override {
132     mlir::RewritePatternSet patterns(&getContext());
133     populateWithGenerated(patterns);
134 
135     // Verify named pattern is generated with expected name.
136     patterns.add<FoldingPattern, TestNamedPatternRule,
137                  FolderInsertBeforePreviouslyFoldedConstantPattern>(
138         &getContext());
139 
140     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
141   }
142 };
143 } // namespace
144 
145 //===----------------------------------------------------------------------===//
146 // ReturnType Driver.
147 //===----------------------------------------------------------------------===//
148 
149 namespace {
// Generate ops for each instance where the type can be successfully inferred.
//
// For every ordered pair (i, j) of the parent function's arguments, including
// i == j, asks OpTy to infer return types from those two values plus the
// attributes and regions of `op`. On success, builds an OpTy with the same two
// operands and `op`'s attributes at an unknown location. The op is inserted
// immediately after `op`; within one call, new ops keep their creation order.
// The inferred types only gate creation; they are not passed to `build`.
template <typename OpTy>
static void invokeCreateWithInferredReturnType(Operation *op) {
  auto *context = op->getContext();
  auto fop = op->getParentOfType<FuncOp>();
  auto location = UnknownLoc::get(context);
  OpBuilder b(op);
  b.setInsertionPointAfter(op);

  // Use ordered pairs (with repetition) of the function arguments as operands;
  // the test expects at least two arguments.
  assert(fop.getNumArguments() >= 2);
  for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
    for (int j = 0; j < e; ++j) {
      std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
      SmallVector<Type, 2> inferredReturnTypes;
      if (succeeded(OpTy::inferReturnTypes(
              context, llvm::None, values, op->getAttrDictionary(),
              op->getRegions(), inferredReturnTypes))) {
        OperationState state(location, OpTy::getOperationName());
        // TODO: Expand to regions.
        OpTy::build(b, state, values, op->getAttrs());
        (void)b.createOperation(state);
      }
    }
  }
}
176 
177 static void reifyReturnShape(Operation *op) {
178   OpBuilder b(op);
179 
180   // Use permutations of 2 args as operands.
181   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
182   SmallVector<Value, 2> shapes;
183   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
184       !llvm::hasSingleElement(shapes))
185     return;
186   for (const auto &it : llvm::enumerate(shapes)) {
187     op->emitRemark() << "value " << it.index() << ": "
188                      << it.value().getDefiningOp();
189   }
190 }
191 
192 struct TestReturnTypeDriver
193     : public PassWrapper<TestReturnTypeDriver, OperationPass<FuncOp>> {
194   void getDependentDialects(DialectRegistry &registry) const override {
195     registry.insert<tensor::TensorDialect>();
196   }
197   StringRef getArgument() const final { return "test-return-type"; }
198   StringRef getDescription() const final { return "Run return type functions"; }
199 
200   void runOnOperation() override {
201     if (getOperation().getName() == "testCreateFunctions") {
202       std::vector<Operation *> ops;
203       // Collect ops to avoid triggering on inserted ops.
204       for (auto &op : getOperation().getBody().front())
205         ops.push_back(&op);
206       // Generate test patterns for each, but skip terminator.
207       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
208         // Test create method of each of the Op classes below. The resultant
209         // output would be in reverse order underneath `op` from which
210         // the attributes and regions are used.
211         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
212         invokeCreateWithInferredReturnType<
213             OpWithShapedTypeInferTypeInterfaceOp>(op);
214       };
215       return;
216     }
217     if (getOperation().getName() == "testReifyFunctions") {
218       std::vector<Operation *> ops;
219       // Collect ops to avoid triggering on inserted ops.
220       for (auto &op : getOperation().getBody().front())
221         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
222           ops.push_back(&op);
223       // Generate test patterns for each, but skip terminator.
224       for (auto *op : ops)
225         reifyReturnShape(op);
226     }
227   }
228 };
229 } // namespace
230 
231 namespace {
/// Pass (`-test-derived-attr`) that materializes the derived attributes of every
/// op implementing DerivedAttributeOpInterface and reports each one as a remark.
struct TestDerivedAttributeDriver
    : public PassWrapper<TestDerivedAttributeDriver, OperationPass<FuncOp>> {
  StringRef getArgument() const final { return "test-derived-attr"; }
  StringRef getDescription() const final {
    return "Run test derived attributes";
  }
  void runOnOperation() override;
};
240 } // namespace
241 
242 void TestDerivedAttributeDriver::runOnOperation() {
243   getOperation().walk([](DerivedAttributeOpInterface dOp) {
244     auto dAttr = dOp.materializeDerivedAttributes();
245     if (!dAttr)
246       return;
247     for (auto d : dAttr)
248       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
249   });
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // Legalization Driver.
254 //===----------------------------------------------------------------------===//
255 
256 namespace {
257 //===----------------------------------------------------------------------===//
258 // Region-Block Rewrite Testing
259 
/// This pattern is a simple pattern that inlines the first region of a given
/// operation into the parent region.
///
/// The first region is moved to the end of the op's parent region. If the op
/// carries `legalizer.should_clone`, the region is cloned there instead. If the
/// op carries `legalizer.erase_old_blocks`, any blocks still left in the op's
/// own region are erased afterwards. This is presumably only relevant in the
/// cloning case, since inlining moves the blocks out. Finally, the op itself
/// is erased.
struct TestRegionRewriteBlockMovement : public ConversionPattern {
  TestRegionRewriteBlockMovement(MLIRContext *ctx)
      : ConversionPattern("test.region", 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Inline this region into the parent region.
    auto &parentRegion = *op->getParentRegion();
    auto &opRegion = op->getRegion(0);
    if (op->getAttr("legalizer.should_clone"))
      rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
    else
      rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());

    // Optionally erase whatever blocks remain in the source region.
    if (op->getAttr("legalizer.erase_old_blocks")) {
      while (!opRegion.empty())
        rewriter.eraseBlock(&opRegion.front());
    }

    // Drop this operation.
    rewriter.eraseOp(op);
    return success();
  }
};
/// This pattern is a simple pattern that generates a region containing an
/// illegal operation.
///
/// It replaces a `test.region_builder` op with a `test.region` op. The new
/// op's region gets an entry block with one i64 argument, holding an ILLegalOpF
/// (i32) followed by a TestValidOp. The explicitly illegal op is meant to make
/// the conversion fail, so that the driver has to undo the region creation.
struct TestRegionRewriteUndo : public RewritePattern {
  TestRegionRewriteUndo(MLIRContext *ctx)
      : RewritePattern("test.region_builder", 1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    // Create the region operation with an entry block containing arguments.
    OperationState newRegion(op->getLoc(), "test.region");
    newRegion.addRegion();
    auto *regionOp = rewriter.createOperation(newRegion);
    // createBlock also moves the insertion point into the new block, so the
    // ops created below land inside it.
    auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
    entryBlock->addArgument(rewriter.getIntegerType(64),
                            rewriter.getUnknownLoc());

    // Add an explicitly illegal operation to ensure the conversion fails.
    rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
    rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());

    // Drop this operation.
    rewriter.eraseOp(op);
    return success();
  }
};
/// A simple pattern that creates a block at the end of the parent region of the
/// matched operation.
///
/// The new block has two i32 arguments and contains only a TerminatorOp. The
/// matched `test.create_block` op (which has no results) is then removed.
struct TestCreateBlock : public RewritePattern {
  TestCreateBlock(MLIRContext *ctx)
      : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    Region &region = *op->getParentRegion();
    Type i32Type = rewriter.getIntegerType(32);
    Location loc = op->getLoc();
    rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
    rewriter.create<TerminatorOp>(loc);
    rewriter.replaceOp(op, {});
    return success();
  }
};
329 
/// A simple pattern that creates a block containing an invalid operation in
/// order to trigger the block creation undo mechanism.
///
/// Same as TestCreateBlock, except that an ILLegalOpF (i32) is placed before the
/// terminator of the new block.
struct TestCreateIllegalBlock : public RewritePattern {
  TestCreateIllegalBlock(MLIRContext *ctx)
      : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    Region &region = *op->getParentRegion();
    Type i32Type = rewriter.getIntegerType(32);
    Location loc = op->getLoc();
    rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
    // Create an illegal op to ensure the conversion fails.
    rewriter.create<ILLegalOpF>(loc, i32Type);
    rewriter.create<TerminatorOp>(loc);
    rewriter.replaceOp(op, {});
    return success();
  }
};
349 
/// A simple pattern that tests the undo mechanism when replacing the uses of a
/// block argument.
///
/// It creates an illegal ILLegalOpF (f32) and redirects all uses of the first
/// argument of the op's first region to it. The illegal op should make the
/// conversion fail, so the block-argument replacement must be rolled back.
struct TestUndoBlockArgReplace : public ConversionPattern {
  TestUndoBlockArgReplace(MLIRContext *ctx)
      : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto illegalOp =
        rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
                                        illegalOp);
    // Empty in-place update: records `op` as modified by this pattern.
    rewriter.updateRootInPlace(op, [] {});
    return success();
  }
};
367 
/// A rewrite pattern that tests the undo mechanism when erasing a block.
///
/// It inserts an illegal ILLegalOpF (f32) at the start of the second block of
/// the op's first region, then erases that block. The region must therefore
/// contain at least two blocks (not checked here).
struct TestUndoBlockErase : public ConversionPattern {
  TestUndoBlockErase(MLIRContext *ctx)
      : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Block *secondBlock = &*std::next(op->getRegion(0).begin());
    rewriter.setInsertionPointToStart(secondBlock);
    rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
    rewriter.eraseBlock(secondBlock);
    // Empty in-place update: records `op` as modified by this pattern.
    rewriter.updateRootInPlace(op, [] {});
    return success();
  }
};
384 
385 //===----------------------------------------------------------------------===//
386 // Type-Conversion Rewrite Testing
387 
/// This pattern erases a region operation that has had a type conversion.
///
/// The entry-block argument types of the op's first region are converted with
/// the pattern's type converter, and the region types are rewritten. The op is
/// then erased. The pattern fails if either conversion step fails.
struct TestDropOpSignatureConversion : public ConversionPattern {
  TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
      : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Region &region = op->getRegion(0);
    Block *entry = &region.front();

    // Convert the original entry arguments.
    TypeConverter &converter = *getTypeConverter();
    TypeConverter::SignatureConversion result(entry->getNumArguments());
    if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
                                              result)) ||
        failed(rewriter.convertRegionTypes(&region, converter, &result)))
      return failure();

    // The region signature has been converted above; now just drop the
    // operation.
    rewriter.eraseOp(op);
    return success();
  }
};
/// This pattern replaces a `test.invalid` op with a TestValidOp that takes the
/// already-remapped operands, effectively passing the operands through.
struct TestPassthroughInvalidOp : public ConversionPattern {
  TestPassthroughInvalidOp(MLIRContext *ctx)
      : ConversionPattern("test.invalid", 1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
                                             llvm::None);
    return success();
  }
};
423 /// This pattern handles the case of a split return value.
424 struct TestSplitReturnType : public ConversionPattern {
425   TestSplitReturnType(MLIRContext *ctx)
426       : ConversionPattern("test.return", 1, ctx) {}
427   LogicalResult
428   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
429                   ConversionPatternRewriter &rewriter) const final {
430     // Check for a return of F32.
431     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
432       return failure();
433 
434     // Check if the first operation is a cast operation, if it is we use the
435     // results directly.
436     auto *defOp = operands[0].getDefiningOp();
437     if (auto packerOp =
438             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
439       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
440       return success();
441     }
442 
443     // Otherwise, fail to match.
444     return failure();
445   }
446 };
447 
448 //===----------------------------------------------------------------------===//
449 // Multi-Level Type-Conversion Rewrite Testing
450 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
451   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
452       : ConversionPattern("test.type_producer", 1, ctx) {}
453   LogicalResult
454   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
455                   ConversionPatternRewriter &rewriter) const final {
456     // If the type is I32, change the type to F32.
457     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
458       return failure();
459     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
460     return success();
461   }
462 };
463 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
464   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
465       : ConversionPattern("test.type_producer", 1, ctx) {}
466   LogicalResult
467   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
468                   ConversionPatternRewriter &rewriter) const final {
469     // If the type is F32, change the type to F64.
470     if (!Type(*op->result_type_begin()).isF32())
471       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
472     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
473     return success();
474   }
475 };
/// Rewrites any `test.type_producer` into one producing bf16. With benefit 10
/// it outranks the other producer patterns (benefit 1).
struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
  TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
      : ConversionPattern("test.type_producer", 10, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Always convert to BF16, even though it is not a legal type. This tests
    // that values are unmapped correctly.
    rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
    return success();
  }
};
488 struct TestUpdateConsumerType : public ConversionPattern {
489   TestUpdateConsumerType(MLIRContext *ctx)
490       : ConversionPattern("test.type_consumer", 1, ctx) {}
491   LogicalResult
492   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
493                   ConversionPatternRewriter &rewriter) const final {
494     // Verify that the incoming operand has been successfully remapped to F64.
495     if (!operands[0].getType().isF64())
496       return failure();
497     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
498     return success();
499   }
500 };
501 
502 //===----------------------------------------------------------------------===//
503 // Non-Root Replacement Rewrite Testing
/// This pattern generates an invalid operation, but replaces it before the
/// pattern is finished. This checks that we don't need to legalize the
/// temporary op.
///
/// The matched op is replaced with an ILLegalOpF, which has itself already been
/// replaced by a LegalOpB of the same result type.
struct TestNonRootReplacement : public RewritePattern {
  TestNonRootReplacement(MLIRContext *ctx)
      : RewritePattern("test.replace_non_root", 1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    auto resultType = *op->result_type_begin();
    auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
    auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);

    // Replace the temporary illegal op first, then the root.
    rewriter.replaceOp(illegalOp, {legalOp});
    rewriter.replaceOp(op, {illegalOp});
    return success();
  }
};
522 
523 //===----------------------------------------------------------------------===//
524 // Recursive Rewrite Testing
/// This pattern is applied to the same operation multiple times, but has a
/// bounded recursion.
///
/// Each application decrements the op's `depth` attribute by one, in place.
struct TestBoundedRecursiveRewrite
    : public OpRewritePattern<TestRecursiveRewriteOp> {
  using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;

  void initialize() {
    // The conversion target handles bounding the recursion of this pattern.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
                                PatternRewriter &rewriter) const final {
    // Decrement the depth of the op in-place.
    rewriter.updateRootInPlace(op, [&] {
      op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
    });
    return success();
  }
};
545 
546 struct TestNestedOpCreationUndoRewrite
547     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
548   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
549 
550   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
551                                 PatternRewriter &rewriter) const final {
552     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
553     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
554     return success();
555   };
556 };
557 
558 // This pattern matches `test.blackhole` and delete this op and its producer.
559 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
560   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
561 
562   LogicalResult matchAndRewrite(BlackHoleOp op,
563                                 PatternRewriter &rewriter) const final {
564     Operation *producer = op.getOperand().getDefiningOp();
565     // Always erase the user before the producer, the framework should handle
566     // this correctly.
567     rewriter.eraseOp(op);
568     rewriter.eraseOp(producer);
569     return success();
570   };
571 };
572 
573 // This pattern replaces explicitly illegal op with explicitly legal op,
574 // but in addition creates unregistered operation.
575 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
576   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
577 
578   LogicalResult matchAndRewrite(ILLegalOpG op,
579                                 PatternRewriter &rewriter) const final {
580     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
581     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
582     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
583     return success();
584   };
585 };
586 } // namespace
587 
588 namespace {
/// Type converter used by the legalization test driver:
///   - signless i16 is dropped (converted to zero types),
///   - signless i64 becomes f64,
///   - i42 (any signedness) becomes a signless i43,
///   - f32 is split into two f16 values,
///   - every other type is kept unchanged.
/// Argument and source materializations are built as TestCastOps.
struct TestTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;
  TestTypeConverter() {
    addConversion(convertType);
    addArgumentMaterialization(materializeCast);
    addSourceMaterialization(materializeCast);
  }

  /// Converts `t` into zero, one or two result types appended to `results`.
  /// Always succeeds.
  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
    // Drop I16 types.
    if (t.isSignlessInteger(16))
      return success();

    // Convert I64 to F64.
    if (t.isSignlessInteger(64)) {
      results.push_back(FloatType::getF64(t.getContext()));
      return success();
    }

    // Convert I42 to I43.
    if (t.isInteger(42)) {
      results.push_back(IntegerType::get(t.getContext(), 43));
      return success();
    }

    // Split F32 into F16,F16.
    if (t.isF32()) {
      results.assign(2, FloatType::getF16(t.getContext()));
      return success();
    }

    // Otherwise, convert the type directly.
    results.push_back(t);
    return success();
  }

  /// Hook for materializing a conversion. This is necessary because we generate
  /// 1->N type mappings. Builds a TestCastOp from `inputs` to `resultType`;
  /// never reports failure.
  static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
                                         ValueRange inputs, Location loc) {
    return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
  }
};
632 
/// Pass (`-test-legalize-patterns`) that runs the test dialect legalization
/// patterns over a module in one of three conversion modes (analysis, full or
/// partial) and reports the outcome through remarks.
struct TestLegalizePatternDriver
    : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "test-legalize-patterns"; }
  StringRef getDescription() const final {
    return "Run test dialect legalization patterns";
  }
  /// The mode of conversion to use with the driver.
  enum class ConversionMode { Analysis, Full, Partial };

  /// Creates the driver running in the given conversion `mode`.
  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<StandardOpsDialect>();
  }

  void runOnOperation() override {
    TestTypeConverter converter;
    mlir::RewritePatternSet patterns(&getContext());
    populateWithGenerated(patterns);
    patterns
        .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
             TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
             TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
             TestNonRootReplacement, TestBoundedRecursiveRewrite,
             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
             TestCreateUnregisteredOp>(&getContext());
    // Patterns that rely on the type converter; `converter` outlives the
    // conversion calls below because it is a local of this function.
    patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
    mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
                                                                   converter);
    mlir::populateCallOpTypeConversionPattern(patterns, converter);

    // Define the conversion target used for the test.
    ConversionTarget target(getContext());
    target.addLegalOp<ModuleOp>();
    target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                      TerminatorOp>();
    target
        .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
      // Don't allow F32 operands.
      return llvm::none_of(op.getOperandTypes(),
                           [](Type type) { return type.isF32(); });
    });
    // A function is legal once both its signature and its body types are
    // legal under the test type converter.
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      return converter.isSignatureLegal(op.getType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<CallOp>(
        [&](CallOp op) { return converter.isLegal(op); });

    // TestCreateUnregisteredOp creates `arith.constant` operation,
    // which was not added to target intentionally to test
    // correct error code from conversion driver.
    target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });

    // Expect the type_producer/type_consumer operations to only operate on f64.
    target.addDynamicallyLegalOp<TestTypeProducerOp>(
        [](TestTypeProducerOp op) { return op.getType().isF64(); });
    target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
      return op.getOperand().getType().isF64();
    });

    // Check support for marking certain operations as recursively legal.
    target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
      return static_cast<bool>(
          op->getAttrOfType<UnitAttr>("test.recursively_legal"));
    });

    // Mark the bound recursion operation as dynamically legal.
    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
        [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });

    // Handle a partial conversion.
    if (mode == ConversionMode::Partial) {
      DenseSet<Operation *> unlegalizedOps;
      if (failed(applyPartialConversion(
              getOperation(), target, std::move(patterns), &unlegalizedOps))) {
        getOperation()->emitRemark() << "applyPartialConversion failed";
      }
      // Emit remarks for each operation that could not be legalized.
      for (auto *op : unlegalizedOps)
        op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
      return;
    }

    // Handle a full conversion.
    if (mode == ConversionMode::Full) {
      // Check support for marking unknown operations as dynamically legal.
      target.markUnknownOpDynamicallyLegal([](Operation *op) {
        return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
      });

      if (failed(applyFullConversion(getOperation(), target,
                                     std::move(patterns)))) {
        getOperation()->emitRemark() << "applyFullConversion failed";
      }
      return;
    }

    // Otherwise, handle an analysis conversion.
    assert(mode == ConversionMode::Analysis);

    // Analyze the convertible operations.
    DenseSet<Operation *> legalizedOps;
    if (failed(applyAnalysisConversion(getOperation(), target,
                                       std::move(patterns), legalizedOps)))
      return signalPassFailure();

    // Emit remarks for each legalizable operation.
    for (auto *op : legalizedOps)
      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
  }

  /// The mode of conversion to use.
  ConversionMode mode;
};
751 } // namespace
752 
/// Command-line option (`-test-legalize-mode`) selecting which dialect
/// conversion entry point `TestLegalizePatternDriver` uses: an analysis,
/// full, or partial conversion. Defaults to a partial conversion. The value is
/// passed to the pass constructor by the custom allocator used when
/// registering `TestLegalizePatternDriver`.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
765 
766 //===----------------------------------------------------------------------===//
767 // ConversionPatternRewriter::getRemappedValue testing. This method is used
768 // to get the remapped value of an original value that was replaced using
769 // ConversionPatternRewriter.
770 namespace {
771 struct TestRemapValueTypeConverter : public TypeConverter {
772   using TypeConverter::TypeConverter;
773 
774   TestRemapValueTypeConverter() {
775     addConversion(
776         [](Float32Type type) { return Float64Type::get(type.getContext()); });
777     addConversion([](Type type) { return type; });
778   }
779 };
780 
781 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
782 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
783 /// operand twice.
784 ///
785 /// Example:
786 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
787 /// is replaced with:
788 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
789 struct OneVResOneVOperandOp1Converter
790     : public OpConversionPattern<OneVResOneVOperandOp1> {
791   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
792 
793   LogicalResult
794   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
795                   ConversionPatternRewriter &rewriter) const override {
796     auto origOps = op.getOperands();
797     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
798            "One operand expected");
799     Value origOp = *origOps.begin();
800     SmallVector<Value, 2> remappedOperands;
801     // Replicate the remapped original operand twice. Note that we don't used
802     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
803     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
804     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
805 
806     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
807                                                        remappedOperands);
808     return success();
809   }
810 };
811 
812 /// A rewriter pattern that tests that blocks can be merged.
813 struct TestRemapValueInRegion
814     : public OpConversionPattern<TestRemappedValueRegionOp> {
815   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
816 
817   LogicalResult
818   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
819                   ConversionPatternRewriter &rewriter) const final {
820     Block &block = op.getBody().front();
821     Operation *terminator = block.getTerminator();
822 
823     // Merge the block into the parent region.
824     Block *parentBlock = op->getBlock();
825     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
826     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
827     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
828 
829     // Replace the results of this operation with the remapped terminator
830     // values.
831     SmallVector<Value> terminatorOperands;
832     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
833                                           terminatorOperands)))
834       return failure();
835 
836     rewriter.eraseOp(terminator);
837     rewriter.replaceOp(op, terminatorOperands);
838     return success();
839   }
840 };
841 
842 struct TestRemappedValue
843     : public mlir::PassWrapper<TestRemappedValue, OperationPass<FuncOp>> {
844   StringRef getArgument() const final { return "test-remapped-value"; }
845   StringRef getDescription() const final {
846     return "Test public remapped value mechanism in ConversionPatternRewriter";
847   }
848   void runOnOperation() override {
849     TestRemapValueTypeConverter typeConverter;
850 
851     mlir::RewritePatternSet patterns(&getContext());
852     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
853     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
854         &getContext());
855     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
856 
857     mlir::ConversionTarget target(getContext());
858     target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
859 
860     // Expect the type_producer/type_consumer operations to only operate on f64.
861     target.addDynamicallyLegalOp<TestTypeProducerOp>(
862         [](TestTypeProducerOp op) { return op.getType().isF64(); });
863     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
864       return op.getOperand().getType().isF64();
865     });
866 
867     // We make OneVResOneVOperandOp1 legal only when it has more that one
868     // operand. This will trigger the conversion that will replace one-operand
869     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
870     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
871         [](Operation *op) { return op->getNumOperands() > 1; });
872 
873     if (failed(mlir::applyFullConversion(getOperation(), target,
874                                          std::move(patterns)))) {
875       signalPassFailure();
876     }
877   }
878 };
879 } // namespace
880 
881 //===----------------------------------------------------------------------===//
882 // Test patterns without a specific root operation kind
883 //===----------------------------------------------------------------------===//
884 
885 namespace {
886 /// This pattern matches and removes any operation in the test dialect.
887 struct RemoveTestDialectOps : public RewritePattern {
888   RemoveTestDialectOps(MLIRContext *context)
889       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
890 
891   LogicalResult matchAndRewrite(Operation *op,
892                                 PatternRewriter &rewriter) const override {
893     if (!isa<TestDialect>(op->getDialect()))
894       return failure();
895     rewriter.eraseOp(op);
896     return success();
897   }
898 };
899 
900 struct TestUnknownRootOpDriver
901     : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<FuncOp>> {
902   StringRef getArgument() const final {
903     return "test-legalize-unknown-root-patterns";
904   }
905   StringRef getDescription() const final {
906     return "Test public remapped value mechanism in ConversionPatternRewriter";
907   }
908   void runOnOperation() override {
909     mlir::RewritePatternSet patterns(&getContext());
910     patterns.add<RemoveTestDialectOps>(&getContext());
911 
912     mlir::ConversionTarget target(getContext());
913     target.addIllegalDialect<TestDialect>();
914     if (failed(applyPartialConversion(getOperation(), target,
915                                       std::move(patterns))))
916       signalPassFailure();
917   }
918 };
919 } // namespace
920 
921 //===----------------------------------------------------------------------===//
922 // Test type conversions
923 //===----------------------------------------------------------------------===//
924 
925 namespace {
926 struct TestTypeConversionProducer
927     : public OpConversionPattern<TestTypeProducerOp> {
928   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
929   LogicalResult
930   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
931                   ConversionPatternRewriter &rewriter) const final {
932     Type resultType = op.getType();
933     Type convertedType = getTypeConverter()
934                              ? getTypeConverter()->convertType(resultType)
935                              : resultType;
936     if (resultType.isa<FloatType>())
937       resultType = rewriter.getF64Type();
938     else if (resultType.isInteger(16))
939       resultType = rewriter.getIntegerType(64);
940     else if (resultType.isa<test::TestRecursiveType>() &&
941              convertedType != resultType)
942       resultType = convertedType;
943     else
944       return failure();
945 
946     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
947     return success();
948   }
949 };
950 
951 /// Call signature conversion and then fail the rewrite to trigger the undo
952 /// mechanism.
953 struct TestSignatureConversionUndo
954     : public OpConversionPattern<TestSignatureConversionUndoOp> {
955   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
956 
957   LogicalResult
958   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
959                   ConversionPatternRewriter &rewriter) const final {
960     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
961     return failure();
962   }
963 };
964 
965 /// Call signature conversion without providing a type converter to handle
966 /// materializations.
967 struct TestTestSignatureConversionNoConverter
968     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
969   TestTestSignatureConversionNoConverter(TypeConverter &converter,
970                                          MLIRContext *context)
971       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
972         converter(converter) {}
973 
974   LogicalResult
975   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
976                   ConversionPatternRewriter &rewriter) const final {
977     Region &region = op->getRegion(0);
978     Block *entry = &region.front();
979 
980     // Convert the original entry arguments.
981     TypeConverter::SignatureConversion result(entry->getNumArguments());
982     if (failed(
983             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
984       return failure();
985     rewriter.updateRootInPlace(
986         op, [&] { rewriter.applySignatureConversion(&region, result); });
987     return success();
988   }
989 
990   TypeConverter &converter;
991 };
992 
993 /// Just forward the operands to the root op. This is essentially a no-op
994 /// pattern that is used to trigger target materialization.
995 struct TestTypeConsumerForward
996     : public OpConversionPattern<TestTypeConsumerOp> {
997   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
998 
999   LogicalResult
1000   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1001                   ConversionPatternRewriter &rewriter) const final {
1002     rewriter.updateRootInPlace(op,
1003                                [&] { op->setOperands(adaptor.getOperands()); });
1004     return success();
1005   }
1006 };
1007 
1008 struct TestTypeConversionAnotherProducer
1009     : public OpRewritePattern<TestAnotherTypeProducerOp> {
1010   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1011 
1012   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1013                                 PatternRewriter &rewriter) const final {
1014     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1015     return success();
1016   }
1017 };
1018 
1019 struct TestTypeConversionDriver
1020     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
1021   void getDependentDialects(DialectRegistry &registry) const override {
1022     registry.insert<TestDialect>();
1023   }
1024   StringRef getArgument() const final {
1025     return "test-legalize-type-conversion";
1026   }
1027   StringRef getDescription() const final {
1028     return "Test various type conversion functionalities in DialectConversion";
1029   }
1030 
1031   void runOnOperation() override {
1032     // Initialize the type converter.
1033     TypeConverter converter;
1034 
1035     /// Add the legal set of type conversions.
1036     converter.addConversion([](Type type) -> Type {
1037       // Treat F64 as legal.
1038       if (type.isF64())
1039         return type;
1040       // Allow converting BF16/F16/F32 to F64.
1041       if (type.isBF16() || type.isF16() || type.isF32())
1042         return FloatType::getF64(type.getContext());
1043       // Otherwise, the type is illegal.
1044       return nullptr;
1045     });
1046     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1047       // Drop all integer types.
1048       return success();
1049     });
1050     converter.addConversion(
1051         // Convert a recursive self-referring type into a non-self-referring
1052         // type named "outer_converted_type" that contains a SimpleAType.
1053         [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
1054             ArrayRef<Type> callStack) -> Optional<LogicalResult> {
1055           // If the type is already converted, return it to indicate that it is
1056           // legal.
1057           if (type.getName() == "outer_converted_type") {
1058             results.push_back(type);
1059             return success();
1060           }
1061 
1062           // If the type is on the call stack more than once (it is there at
1063           // least once because of the _current_ call, which is always the last
1064           // element on the stack), we've hit the recursive case. Just return
1065           // SimpleAType here to create a non-recursive type as a result.
1066           if (llvm::is_contained(callStack.drop_back(), type)) {
1067             results.push_back(test::SimpleAType::get(type.getContext()));
1068             return success();
1069           }
1070 
1071           // Convert the body recursively.
1072           auto result = test::TestRecursiveType::get(type.getContext(),
1073                                                      "outer_converted_type");
1074           if (failed(result.setBody(converter.convertType(type.getBody()))))
1075             return failure();
1076           results.push_back(result);
1077           return success();
1078         });
1079 
1080     /// Add the legal set of type materializations.
1081     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1082                                           ValueRange inputs,
1083                                           Location loc) -> Value {
1084       // Allow casting from F64 back to F32.
1085       if (!resultType.isF16() && inputs.size() == 1 &&
1086           inputs[0].getType().isF64())
1087         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1088       // Allow producing an i32 or i64 from nothing.
1089       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1090           inputs.empty())
1091         return builder.create<TestTypeProducerOp>(loc, resultType);
1092       // Allow producing an i64 from an integer.
1093       if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
1094           inputs[0].getType().isa<IntegerType>())
1095         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1096       // Otherwise, fail.
1097       return nullptr;
1098     });
1099 
1100     // Initialize the conversion target.
1101     mlir::ConversionTarget target(getContext());
1102     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1103       auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
1104       return op.getType().isF64() || op.getType().isInteger(64) ||
1105              (recursiveType &&
1106               recursiveType.getName() == "outer_converted_type");
1107     });
1108     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
1109       return converter.isSignatureLegal(op.getType()) &&
1110              converter.isLegal(&op.getBody());
1111     });
1112     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1113       // Allow casts from F64 to F32.
1114       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1115     });
1116     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1117         [&](TestSignatureConversionNoConverterOp op) {
1118           return converter.isLegal(op.getRegion().front().getArgumentTypes());
1119         });
1120 
1121     // Initialize the set of rewrite patterns.
1122     RewritePatternSet patterns(&getContext());
1123     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1124                  TestSignatureConversionUndo,
1125                  TestTestSignatureConversionNoConverter>(converter,
1126                                                          &getContext());
1127     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1128     mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
1129                                                                    converter);
1130 
1131     if (failed(applyPartialConversion(getOperation(), target,
1132                                       std::move(patterns))))
1133       signalPassFailure();
1134   }
1135 };
1136 } // namespace
1137 
1138 //===----------------------------------------------------------------------===//
1139 // Test Block Merging
1140 //===----------------------------------------------------------------------===//
1141 
1142 namespace {
1143 /// A rewriter pattern that tests that blocks can be merged.
1144 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1145   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1146 
1147   LogicalResult
1148   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1149                   ConversionPatternRewriter &rewriter) const final {
1150     Block &firstBlock = op.getBody().front();
1151     Operation *branchOp = firstBlock.getTerminator();
1152     Block *secondBlock = &*(std::next(op.getBody().begin()));
1153     auto succOperands = branchOp->getOperands();
1154     SmallVector<Value, 2> replacements(succOperands);
1155     rewriter.eraseOp(branchOp);
1156     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1157     rewriter.updateRootInPlace(op, [] {});
1158     return success();
1159   }
1160 };
1161 
1162 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1163 struct TestUndoBlocksMerge : public ConversionPattern {
1164   TestUndoBlocksMerge(MLIRContext *ctx)
1165       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1166   LogicalResult
1167   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1168                   ConversionPatternRewriter &rewriter) const final {
1169     Block &firstBlock = op->getRegion(0).front();
1170     Operation *branchOp = firstBlock.getTerminator();
1171     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1172     rewriter.setInsertionPointToStart(secondBlock);
1173     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1174     auto succOperands = branchOp->getOperands();
1175     SmallVector<Value, 2> replacements(succOperands);
1176     rewriter.eraseOp(branchOp);
1177     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1178     rewriter.updateRootInPlace(op, [] {});
1179     return success();
1180   }
1181 };
1182 
1183 /// A rewrite mechanism to inline the body of the op into its parent, when both
1184 /// ops can have a single block.
1185 struct TestMergeSingleBlockOps
1186     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1187   using OpConversionPattern<
1188       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1189 
1190   LogicalResult
1191   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1192                   ConversionPatternRewriter &rewriter) const final {
1193     SingleBlockImplicitTerminatorOp parentOp =
1194         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1195     if (!parentOp)
1196       return failure();
1197     Block &innerBlock = op.getRegion().front();
1198     TerminatorOp innerTerminator =
1199         cast<TerminatorOp>(innerBlock.getTerminator());
1200     rewriter.mergeBlockBefore(&innerBlock, op);
1201     rewriter.eraseOp(innerTerminator);
1202     rewriter.eraseOp(op);
1203     rewriter.updateRootInPlace(op, [] {});
1204     return success();
1205   }
1206 };
1207 
1208 struct TestMergeBlocksPatternDriver
1209     : public PassWrapper<TestMergeBlocksPatternDriver,
1210                          OperationPass<ModuleOp>> {
1211   StringRef getArgument() const final { return "test-merge-blocks"; }
1212   StringRef getDescription() const final {
1213     return "Test Merging operation in ConversionPatternRewriter";
1214   }
1215   void runOnOperation() override {
1216     MLIRContext *context = &getContext();
1217     mlir::RewritePatternSet patterns(context);
1218     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1219         context);
1220     ConversionTarget target(*context);
1221     target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1222                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1223     target.addIllegalOp<ILLegalOpF>();
1224 
1225     /// Expect the op to have a single block after legalization.
1226     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1227         [&](TestMergeBlocksOp op) -> bool {
1228           return llvm::hasSingleElement(op.getBody());
1229         });
1230 
1231     /// Only allow `test.br` within test.merge_blocks op.
1232     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1233       return op->getParentOfType<TestMergeBlocksOp>();
1234     });
1235 
1236     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1237     /// inlined.
1238     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1239         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1240           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1241         });
1242 
1243     DenseSet<Operation *> unlegalizedOps;
1244     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1245                                  &unlegalizedOps);
1246     for (auto *op : unlegalizedOps)
1247       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1248   }
1249 };
1250 } // namespace
1251 
1252 //===----------------------------------------------------------------------===//
1253 // Test Selective Replacement
1254 //===----------------------------------------------------------------------===//
1255 
1256 namespace {
1257 /// A rewrite mechanism to inline the body of the op into its parent, when both
1258 /// ops can have a single block.
1259 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1260   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1261 
1262   LogicalResult matchAndRewrite(TestCastOp op,
1263                                 PatternRewriter &rewriter) const final {
1264     if (op.getNumOperands() != 2)
1265       return failure();
1266     OperandRange operands = op.getOperands();
1267 
1268     // Replace non-terminator uses with the first operand.
1269     rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1270       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1271     });
1272     // Replace everything else with the second operand if the operation isn't
1273     // dead.
1274     rewriter.replaceOp(op, op.getOperand(1));
1275     return success();
1276   }
1277 };
1278 
1279 struct TestSelectiveReplacementPatternDriver
1280     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1281                          OperationPass<>> {
1282   StringRef getArgument() const final {
1283     return "test-pattern-selective-replacement";
1284   }
1285   StringRef getDescription() const final {
1286     return "Test selective replacement in the PatternRewriter";
1287   }
1288   void runOnOperation() override {
1289     MLIRContext *context = &getContext();
1290     mlir::RewritePatternSet patterns(context);
1291     patterns.add<TestSelectiveOpReplacementPattern>(context);
1292     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1293                                        std::move(patterns));
1294   }
1295 };
1296 } // namespace
1297 
1298 //===----------------------------------------------------------------------===//
1299 // PassRegistration
1300 //===----------------------------------------------------------------------===//
1301 
1302 namespace mlir {
1303 namespace test {
/// Registers the pattern and dialect-conversion test passes listed below via
/// `PassRegistration`. TestLegalizePatternDriver is registered with a custom
/// allocator so that each created instance receives the current value of the
/// `-test-legalize-mode` option (`legalizerConversionMode`).
void registerPatternsTestPass() {
  PassRegistration<TestReturnTypeDriver>();

  PassRegistration<TestDerivedAttributeDriver>();

  PassRegistration<TestPatternDriver>();

  // Custom allocator: forwards the selected conversion mode to the pass.
  PassRegistration<TestLegalizePatternDriver>([] {
    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
  });

  PassRegistration<TestRemappedValue>();

  PassRegistration<TestUnknownRootOpDriver>();

  PassRegistration<TestTypeConversionDriver>();

  PassRegistration<TestMergeBlocksPatternDriver>();
  PassRegistration<TestSelectiveReplacementPatternDriver>();
}
1324 } // namespace test
1325 } // namespace mlir
1326