1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 using namespace test;
22 
23 // Native function for testing NativeCodeCall
24 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
25   return choice.getValue() ? input1 : input2;
26 }
27 
28 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
29   rewriter.create<OpI>(loc, input);
30 }
31 
32 static void handleNoResultOp(PatternRewriter &rewriter,
33                              OpSymbolBindingNoResult op) {
34   // Turn the no result op to a one-result op.
35   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
36                                     op.operand());
37 }
38 
39 static bool getFirstI32Result(Operation *op, Value &value) {
40   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
41     return false;
42   value = op->getResult(0);
43   return true;
44 }
45 
46 static Value bindNativeCodeCallResult(Value value) { return value; }
47 
48 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
49                                                               Value input2) {
50   return SmallVector<Value, 2>({input2, input1});
51 }
52 
// Test that native calls are only invoked once during rewrites.
// OpM_Test returns the digits of Pi (314159265), increased by 1 for each
// subsequent call. This lets us check the number of times OpM_Test was called
// by inspecting the returned value in the MLIR output.
57 static int64_t opMIncreasingValue = 314159265;
58 static Attribute OpMTest(PatternRewriter &rewriter, Value val) {
59   int64_t i = opMIncreasingValue++;
60   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
61 }
62 
namespace {
// TableGen-generated rewrite patterns for the test dialect, kept file-local by
// the anonymous namespace. Presumably the source of `populateWithGenerated`
// and `TestNamedPatternRule` used below (not defined elsewhere in this view).
#include "TestPatterns.inc"
} // end anonymous namespace
66 
67 //===----------------------------------------------------------------------===//
68 // Test Reduce Pattern Interface
69 //===----------------------------------------------------------------------===//
70 
71 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
72   populateWithGenerated(patterns);
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // Canonicalizer Driver.
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
80 struct FoldingPattern : public RewritePattern {
81 public:
82   FoldingPattern(MLIRContext *context)
83       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
84                        /*benefit=*/1, context) {}
85 
86   LogicalResult matchAndRewrite(Operation *op,
87                                 PatternRewriter &rewriter) const override {
88     // Exercise OperationFolder API for a single-result operation that is folded
89     // upon construction. The operation being created through the folder has an
90     // in-place folder, and it should be still present in the output.
91     // Furthermore, the folder should not crash when attempting to recover the
92     // (unchanged) operation result.
93     OperationFolder folder(op->getContext());
94     Value result = folder.create<TestOpInPlaceFold>(
95         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
96         rewriter.getI32IntegerAttr(0));
97     assert(result);
98     rewriter.replaceOp(op, result);
99     return success();
100   }
101 };
102 
103 /// This pattern creates a foldable operation at the entry point of the block.
104 /// This tests the situation where the operation folder will need to replace an
105 /// operation with a previously created constant that does not initially
106 /// dominate the operation to replace.
107 struct FolderInsertBeforePreviouslyFoldedConstantPattern
108     : public OpRewritePattern<TestCastOp> {
109 public:
110   using OpRewritePattern<TestCastOp>::OpRewritePattern;
111 
112   LogicalResult matchAndRewrite(TestCastOp op,
113                                 PatternRewriter &rewriter) const override {
114     if (!op->hasAttr("test_fold_before_previously_folded_op"))
115       return failure();
116     rewriter.setInsertionPointToStart(op->getBlock());
117 
118     auto constOp = rewriter.create<arith::ConstantOp>(
119         op.getLoc(), rewriter.getBoolAttr(true));
120     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
121                                             Value(constOp));
122     return success();
123   }
124 };
125 
126 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
127   StringRef getArgument() const final { return "test-patterns"; }
128   StringRef getDescription() const final { return "Run test dialect patterns"; }
129   void runOnFunction() override {
130     mlir::RewritePatternSet patterns(&getContext());
131     populateWithGenerated(patterns);
132 
133     // Verify named pattern is generated with expected name.
134     patterns.add<FoldingPattern, TestNamedPatternRule,
135                  FolderInsertBeforePreviouslyFoldedConstantPattern>(
136         &getContext());
137 
138     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139   }
140 };
141 } // end anonymous namespace
142 
143 //===----------------------------------------------------------------------===//
144 // ReturnType Driver.
145 //===----------------------------------------------------------------------===//
146 
147 namespace {
148 // Generate ops for each instance where the type can be successfully inferred.
149 template <typename OpTy>
150 static void invokeCreateWithInferredReturnType(Operation *op) {
151   auto *context = op->getContext();
152   auto fop = op->getParentOfType<FuncOp>();
153   auto location = UnknownLoc::get(context);
154   OpBuilder b(op);
155   b.setInsertionPointAfter(op);
156 
157   // Use permutations of 2 args as operands.
158   assert(fop.getNumArguments() >= 2);
159   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
160     for (int j = 0; j < e; ++j) {
161       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
162       SmallVector<Type, 2> inferredReturnTypes;
163       if (succeeded(OpTy::inferReturnTypes(
164               context, llvm::None, values, op->getAttrDictionary(),
165               op->getRegions(), inferredReturnTypes))) {
166         OperationState state(location, OpTy::getOperationName());
167         // TODO: Expand to regions.
168         OpTy::build(b, state, values, op->getAttrs());
169         (void)b.createOperation(state);
170       }
171     }
172   }
173 }
174 
175 static void reifyReturnShape(Operation *op) {
176   OpBuilder b(op);
177 
178   // Use permutations of 2 args as operands.
179   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
180   SmallVector<Value, 2> shapes;
181   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
182       !llvm::hasSingleElement(shapes))
183     return;
184   for (auto it : llvm::enumerate(shapes)) {
185     op->emitRemark() << "value " << it.index() << ": "
186                      << it.value().getDefiningOp();
187   }
188 }
189 
190 struct TestReturnTypeDriver
191     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
192   void getDependentDialects(DialectRegistry &registry) const override {
193     registry.insert<tensor::TensorDialect>();
194   }
195   StringRef getArgument() const final { return "test-return-type"; }
196   StringRef getDescription() const final { return "Run return type functions"; }
197 
198   void runOnFunction() override {
199     if (getFunction().getName() == "testCreateFunctions") {
200       std::vector<Operation *> ops;
201       // Collect ops to avoid triggering on inserted ops.
202       for (auto &op : getFunction().getBody().front())
203         ops.push_back(&op);
204       // Generate test patterns for each, but skip terminator.
205       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
206         // Test create method of each of the Op classes below. The resultant
207         // output would be in reverse order underneath `op` from which
208         // the attributes and regions are used.
209         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
210         invokeCreateWithInferredReturnType<
211             OpWithShapedTypeInferTypeInterfaceOp>(op);
212       };
213       return;
214     }
215     if (getFunction().getName() == "testReifyFunctions") {
216       std::vector<Operation *> ops;
217       // Collect ops to avoid triggering on inserted ops.
218       for (auto &op : getFunction().getBody().front())
219         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
220           ops.push_back(&op);
221       // Generate test patterns for each, but skip terminator.
222       for (auto *op : ops)
223         reifyReturnShape(op);
224     }
225   }
226 };
227 } // end anonymous namespace
228 
229 namespace {
230 struct TestDerivedAttributeDriver
231     : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
232   StringRef getArgument() const final { return "test-derived-attr"; }
233   StringRef getDescription() const final {
234     return "Run test derived attributes";
235   }
236   void runOnFunction() override;
237 };
238 } // end anonymous namespace
239 
240 void TestDerivedAttributeDriver::runOnFunction() {
241   getFunction().walk([](DerivedAttributeOpInterface dOp) {
242     auto dAttr = dOp.materializeDerivedAttributes();
243     if (!dAttr)
244       return;
245     for (auto d : dAttr)
246       dOp.emitRemark() << d.first << " = " << d.second;
247   });
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // Legalization Driver.
252 //===----------------------------------------------------------------------===//
253 
254 namespace {
255 //===----------------------------------------------------------------------===//
256 // Region-Block Rewrite Testing
257 
258 /// This pattern is a simple pattern that inlines the first region of a given
259 /// operation into the parent region.
260 struct TestRegionRewriteBlockMovement : public ConversionPattern {
261   TestRegionRewriteBlockMovement(MLIRContext *ctx)
262       : ConversionPattern("test.region", 1, ctx) {}
263 
264   LogicalResult
265   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
266                   ConversionPatternRewriter &rewriter) const final {
267     // Inline this region into the parent region.
268     auto &parentRegion = *op->getParentRegion();
269     auto &opRegion = op->getRegion(0);
270     if (op->getAttr("legalizer.should_clone"))
271       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
272     else
273       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
274 
275     if (op->getAttr("legalizer.erase_old_blocks")) {
276       while (!opRegion.empty())
277         rewriter.eraseBlock(&opRegion.front());
278     }
279 
280     // Drop this operation.
281     rewriter.eraseOp(op);
282     return success();
283   }
284 };
285 /// This pattern is a simple pattern that generates a region containing an
286 /// illegal operation.
287 struct TestRegionRewriteUndo : public RewritePattern {
288   TestRegionRewriteUndo(MLIRContext *ctx)
289       : RewritePattern("test.region_builder", 1, ctx) {}
290 
291   LogicalResult matchAndRewrite(Operation *op,
292                                 PatternRewriter &rewriter) const final {
293     // Create the region operation with an entry block containing arguments.
294     OperationState newRegion(op->getLoc(), "test.region");
295     newRegion.addRegion();
296     auto *regionOp = rewriter.createOperation(newRegion);
297     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
298     entryBlock->addArgument(rewriter.getIntegerType(64));
299 
300     // Add an explicitly illegal operation to ensure the conversion fails.
301     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
302     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
303 
304     // Drop this operation.
305     rewriter.eraseOp(op);
306     return success();
307   }
308 };
309 /// A simple pattern that creates a block at the end of the parent region of the
310 /// matched operation.
311 struct TestCreateBlock : public RewritePattern {
312   TestCreateBlock(MLIRContext *ctx)
313       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
314 
315   LogicalResult matchAndRewrite(Operation *op,
316                                 PatternRewriter &rewriter) const final {
317     Region &region = *op->getParentRegion();
318     Type i32Type = rewriter.getIntegerType(32);
319     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
320     rewriter.create<TerminatorOp>(op->getLoc());
321     rewriter.replaceOp(op, {});
322     return success();
323   }
324 };
325 
326 /// A simple pattern that creates a block containing an invalid operation in
327 /// order to trigger the block creation undo mechanism.
328 struct TestCreateIllegalBlock : public RewritePattern {
329   TestCreateIllegalBlock(MLIRContext *ctx)
330       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
331 
332   LogicalResult matchAndRewrite(Operation *op,
333                                 PatternRewriter &rewriter) const final {
334     Region &region = *op->getParentRegion();
335     Type i32Type = rewriter.getIntegerType(32);
336     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
337     // Create an illegal op to ensure the conversion fails.
338     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
339     rewriter.create<TerminatorOp>(op->getLoc());
340     rewriter.replaceOp(op, {});
341     return success();
342   }
343 };
344 
345 /// A simple pattern that tests the undo mechanism when replacing the uses of a
346 /// block argument.
347 struct TestUndoBlockArgReplace : public ConversionPattern {
348   TestUndoBlockArgReplace(MLIRContext *ctx)
349       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
350 
351   LogicalResult
352   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
353                   ConversionPatternRewriter &rewriter) const final {
354     auto illegalOp =
355         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
356     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
357                                         illegalOp);
358     rewriter.updateRootInPlace(op, [] {});
359     return success();
360   }
361 };
362 
363 /// A rewrite pattern that tests the undo mechanism when erasing a block.
364 struct TestUndoBlockErase : public ConversionPattern {
365   TestUndoBlockErase(MLIRContext *ctx)
366       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
367 
368   LogicalResult
369   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
370                   ConversionPatternRewriter &rewriter) const final {
371     Block *secondBlock = &*std::next(op->getRegion(0).begin());
372     rewriter.setInsertionPointToStart(secondBlock);
373     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
374     rewriter.eraseBlock(secondBlock);
375     rewriter.updateRootInPlace(op, [] {});
376     return success();
377   }
378 };
379 
380 //===----------------------------------------------------------------------===//
381 // Type-Conversion Rewrite Testing
382 
383 /// This patterns erases a region operation that has had a type conversion.
384 struct TestDropOpSignatureConversion : public ConversionPattern {
385   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
386       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
387   LogicalResult
388   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
389                   ConversionPatternRewriter &rewriter) const override {
390     Region &region = op->getRegion(0);
391     Block *entry = &region.front();
392 
393     // Convert the original entry arguments.
394     TypeConverter &converter = *getTypeConverter();
395     TypeConverter::SignatureConversion result(entry->getNumArguments());
396     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
397                                               result)) ||
398         failed(rewriter.convertRegionTypes(&region, converter, &result)))
399       return failure();
400 
401     // Convert the region signature and just drop the operation.
402     rewriter.eraseOp(op);
403     return success();
404   }
405 };
406 /// This pattern simply updates the operands of the given operation.
407 struct TestPassthroughInvalidOp : public ConversionPattern {
408   TestPassthroughInvalidOp(MLIRContext *ctx)
409       : ConversionPattern("test.invalid", 1, ctx) {}
410   LogicalResult
411   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
412                   ConversionPatternRewriter &rewriter) const final {
413     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
414                                              llvm::None);
415     return success();
416   }
417 };
418 /// This pattern handles the case of a split return value.
419 struct TestSplitReturnType : public ConversionPattern {
420   TestSplitReturnType(MLIRContext *ctx)
421       : ConversionPattern("test.return", 1, ctx) {}
422   LogicalResult
423   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
424                   ConversionPatternRewriter &rewriter) const final {
425     // Check for a return of F32.
426     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
427       return failure();
428 
429     // Check if the first operation is a cast operation, if it is we use the
430     // results directly.
431     auto *defOp = operands[0].getDefiningOp();
432     if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
433       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
434       return success();
435     }
436 
437     // Otherwise, fail to match.
438     return failure();
439   }
440 };
441 
442 //===----------------------------------------------------------------------===//
443 // Multi-Level Type-Conversion Rewrite Testing
444 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
445   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
446       : ConversionPattern("test.type_producer", 1, ctx) {}
447   LogicalResult
448   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
449                   ConversionPatternRewriter &rewriter) const final {
450     // If the type is I32, change the type to F32.
451     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
452       return failure();
453     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
454     return success();
455   }
456 };
457 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
458   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
459       : ConversionPattern("test.type_producer", 1, ctx) {}
460   LogicalResult
461   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
462                   ConversionPatternRewriter &rewriter) const final {
463     // If the type is F32, change the type to F64.
464     if (!Type(*op->result_type_begin()).isF32())
465       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
466     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
467     return success();
468   }
469 };
470 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
471   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
472       : ConversionPattern("test.type_producer", 10, ctx) {}
473   LogicalResult
474   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
475                   ConversionPatternRewriter &rewriter) const final {
476     // Always convert to B16, even though it is not a legal type. This tests
477     // that values are unmapped correctly.
478     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
479     return success();
480   }
481 };
482 struct TestUpdateConsumerType : public ConversionPattern {
483   TestUpdateConsumerType(MLIRContext *ctx)
484       : ConversionPattern("test.type_consumer", 1, ctx) {}
485   LogicalResult
486   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
487                   ConversionPatternRewriter &rewriter) const final {
488     // Verify that the incoming operand has been successfully remapped to F64.
489     if (!operands[0].getType().isF64())
490       return failure();
491     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
492     return success();
493   }
494 };
495 
496 //===----------------------------------------------------------------------===//
497 // Non-Root Replacement Rewrite Testing
498 /// This pattern generates an invalid operation, but replaces it before the
499 /// pattern is finished. This checks that we don't need to legalize the
500 /// temporary op.
501 struct TestNonRootReplacement : public RewritePattern {
502   TestNonRootReplacement(MLIRContext *ctx)
503       : RewritePattern("test.replace_non_root", 1, ctx) {}
504 
505   LogicalResult matchAndRewrite(Operation *op,
506                                 PatternRewriter &rewriter) const final {
507     auto resultType = *op->result_type_begin();
508     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
509     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
510 
511     rewriter.replaceOp(illegalOp, {legalOp});
512     rewriter.replaceOp(op, {illegalOp});
513     return success();
514   }
515 };
516 
517 //===----------------------------------------------------------------------===//
518 // Recursive Rewrite Testing
519 /// This pattern is applied to the same operation multiple times, but has a
520 /// bounded recursion.
521 struct TestBoundedRecursiveRewrite
522     : public OpRewritePattern<TestRecursiveRewriteOp> {
523   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
524 
525   void initialize() {
526     // The conversion target handles bounding the recursion of this pattern.
527     setHasBoundedRewriteRecursion();
528   }
529 
530   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
531                                 PatternRewriter &rewriter) const final {
532     // Decrement the depth of the op in-place.
533     rewriter.updateRootInPlace(op, [&] {
534       op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1));
535     });
536     return success();
537   }
538 };
539 
540 struct TestNestedOpCreationUndoRewrite
541     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
542   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
543 
544   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
545                                 PatternRewriter &rewriter) const final {
546     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
547     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
548     return success();
549   };
550 };
551 
552 // This pattern matches `test.blackhole` and delete this op and its producer.
553 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
554   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
555 
556   LogicalResult matchAndRewrite(BlackHoleOp op,
557                                 PatternRewriter &rewriter) const final {
558     Operation *producer = op.getOperand().getDefiningOp();
559     // Always erase the user before the producer, the framework should handle
560     // this correctly.
561     rewriter.eraseOp(op);
562     rewriter.eraseOp(producer);
563     return success();
564   };
565 };
566 
567 // This pattern replaces explicitly illegal op with explicitly legal op,
568 // but in addition creates unregistered operation.
569 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
570   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
571 
572   LogicalResult matchAndRewrite(ILLegalOpG op,
573                                 PatternRewriter &rewriter) const final {
574     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
575     Value val = rewriter.create<ConstantOp>(op->getLoc(), attr);
576     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
577     return success();
578   };
579 };
580 } // namespace
581 
582 namespace {
583 struct TestTypeConverter : public TypeConverter {
584   using TypeConverter::TypeConverter;
585   TestTypeConverter() {
586     addConversion(convertType);
587     addArgumentMaterialization(materializeCast);
588     addSourceMaterialization(materializeCast);
589 
590     /// Materialize the cast for one-to-one conversion from i64 to f64.
591     const auto materializeOneToOneCast =
592         [](OpBuilder &builder, IntegerType resultType, ValueRange inputs,
593            Location loc) -> Optional<Value> {
594       if (resultType.getWidth() == 42 && inputs.size() == 1)
595         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
596       return llvm::None;
597     };
598     addArgumentMaterialization(materializeOneToOneCast);
599   }
600 
601   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
602     // Drop I16 types.
603     if (t.isSignlessInteger(16))
604       return success();
605 
606     // Convert I64 to F64.
607     if (t.isSignlessInteger(64)) {
608       results.push_back(FloatType::getF64(t.getContext()));
609       return success();
610     }
611 
612     // Convert I42 to I43.
613     if (t.isInteger(42)) {
614       results.push_back(IntegerType::get(t.getContext(), 43));
615       return success();
616     }
617 
618     // Split F32 into F16,F16.
619     if (t.isF32()) {
620       results.assign(2, FloatType::getF16(t.getContext()));
621       return success();
622     }
623 
624     // Otherwise, convert the type directly.
625     results.push_back(t);
626     return success();
627   }
628 
629   /// Hook for materializing a conversion. This is necessary because we generate
630   /// 1->N type mappings.
631   static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
632                                          ValueRange inputs, Location loc) {
633     if (inputs.size() == 1)
634       return inputs[0];
635     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
636   }
637 };
638 
/// Module pass (`-test-legalize-patterns`) that runs the test dialect's
/// legalization patterns through the dialect conversion driver. The
/// conversion mode (analysis, full or partial) is chosen at construction.
struct TestLegalizePatternDriver
    : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "test-legalize-patterns"; }
  StringRef getDescription() const final {
    return "Run test dialect legalization patterns";
  }
  /// The mode of conversion to use with the driver.
  enum class ConversionMode { Analysis, Full, Partial };

  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<StandardOpsDialect>();
  }

  /// Builds the pattern set and conversion target, then dispatches on `mode`:
  ///  - Partial: applies a partial conversion, emits a remark on the module if
  ///    it fails, and emits a remark on every op it could not legalize.
  ///  - Full: additionally treats unknown ops carrying `test.dynamically_legal`
  ///    as legal, applies a full conversion, and emits a remark if it fails.
  ///  - Analysis: runs an analysis conversion, signals pass failure if it
  ///    fails, and otherwise emits a remark on every legalizable op.
  void runOnOperation() override {
    TestTypeConverter converter;
    mlir::RewritePatternSet patterns(&getContext());
    populateWithGenerated(patterns);
    patterns
        .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
             TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
             TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
             TestNonRootReplacement, TestBoundedRecursiveRewrite,
             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
             TestCreateUnregisteredOp>(&getContext());
    // Patterns that need the type converter.
    patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
    mlir::populateFuncOpTypeConversionPattern(patterns, converter);
    mlir::populateCallOpTypeConversionPattern(patterns, converter);

    // Define the conversion target used for the test.
    ConversionTarget target(getContext());
    target.addLegalOp<ModuleOp>();
    target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                      TerminatorOp>();
    target
        .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
      // Don't allow F32 operands.
      return llvm::none_of(op.getOperandTypes(),
                           [](Type type) { return type.isF32(); });
    });
    // A function is legal once its signature and body types are legal for
    // `converter` (captured by reference; it outlives the conversion below).
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      return converter.isSignatureLegal(op.getType()) &&
             converter.isLegal(&op.getBody());
    });

    // ILLegalOpG is never legal, so TestCreateUnregisteredOp must rewrite it.
    // That pattern creates a constant via `ConstantOp`, which is intentionally
    // not added to the target, to test that the conversion driver reports the
    // correct error code.
    target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });

    // Expect the type_producer/type_consumer operations to only operate on f64.
    target.addDynamicallyLegalOp<TestTypeProducerOp>(
        [](TestTypeProducerOp op) { return op.getType().isF64(); });
    target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
      return op.getOperand().getType().isF64();
    });

    // Check support for marking certain operations as recursively legal.
    target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
      return static_cast<bool>(
          op->getAttrOfType<UnitAttr>("test.recursively_legal"));
    });

    // Mark the bound recursion operation as dynamically legal.
    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
        [](TestRecursiveRewriteOp op) { return op.depth() == 0; });

    // Handle a partial conversion.
    if (mode == ConversionMode::Partial) {
      DenseSet<Operation *> unlegalizedOps;
      if (failed(applyPartialConversion(
              getOperation(), target, std::move(patterns), &unlegalizedOps))) {
        getOperation()->emitRemark() << "applyPartialConversion failed";
      }
      // Emit remarks for each operation that could not be legalized.
      for (auto *op : unlegalizedOps)
        op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
      return;
    }

    // Handle a full conversion.
    if (mode == ConversionMode::Full) {
      // Check support for marking unknown operations as dynamically legal.
      target.markUnknownOpDynamicallyLegal([](Operation *op) {
        return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
      });

      if (failed(applyFullConversion(getOperation(), target,
                                     std::move(patterns)))) {
        getOperation()->emitRemark() << "applyFullConversion failed";
      }
      return;
    }

    // Otherwise, handle an analysis conversion.
    assert(mode == ConversionMode::Analysis);

    // Analyze the convertible operations.
    DenseSet<Operation *> legalizedOps;
    if (failed(applyAnalysisConversion(getOperation(), target,
                                       std::move(patterns), legalizedOps)))
      return signalPassFailure();

    // Emit remarks for each legalizable operation.
    for (auto *op : legalizedOps)
      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
  }

  /// The mode of conversion to use.
  ConversionMode mode;
};
754 } // end anonymous namespace
755 
/// Command-line option (`-test-legalize-mode`) selecting how
/// TestLegalizePatternDriver legalizes its input: "analysis", "full" or
/// "partial". Defaults to a partial conversion. Its value is forwarded to every
/// TestLegalizePatternDriver instance created through the pass registration.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
768 
769 //===----------------------------------------------------------------------===//
// ConversionPatternRewriter::getRemappedValue testing. This method is used
// to get the remapped value of an original value that was replaced using
// ConversionPatternRewriter.
//===----------------------------------------------------------------------===//
773 namespace {
774 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
775 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
776 /// operand twice.
777 ///
778 /// Example:
779 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
780 /// is replaced with:
781 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
782 struct OneVResOneVOperandOp1Converter
783     : public OpConversionPattern<OneVResOneVOperandOp1> {
784   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
785 
786   LogicalResult
787   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
788                   ConversionPatternRewriter &rewriter) const override {
789     auto origOps = op.getOperands();
790     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
791            "One operand expected");
792     Value origOp = *origOps.begin();
793     SmallVector<Value, 2> remappedOperands;
794     // Replicate the remapped original operand twice. Note that we don't used
795     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
796     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
797     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
798 
799     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
800                                                        remappedOperands);
801     return success();
802   }
803 };
804 
805 struct TestRemappedValue
806     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
807   StringRef getArgument() const final { return "test-remapped-value"; }
808   StringRef getDescription() const final {
809     return "Test public remapped value mechanism in ConversionPatternRewriter";
810   }
811   void runOnFunction() override {
812     mlir::RewritePatternSet patterns(&getContext());
813     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
814 
815     mlir::ConversionTarget target(getContext());
816     target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
817     // We make OneVResOneVOperandOp1 legal only when it has more that one
818     // operand. This will trigger the conversion that will replace one-operand
819     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
820     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
821         [](Operation *op) -> bool {
822           return std::distance(op->operand_begin(), op->operand_end()) > 1;
823         });
824 
825     if (failed(mlir::applyFullConversion(getFunction(), target,
826                                          std::move(patterns)))) {
827       signalPassFailure();
828     }
829   }
830 };
831 } // end anonymous namespace
832 
833 //===----------------------------------------------------------------------===//
834 // Test patterns without a specific root operation kind
835 //===----------------------------------------------------------------------===//
836 
837 namespace {
838 /// This pattern matches and removes any operation in the test dialect.
839 struct RemoveTestDialectOps : public RewritePattern {
840   RemoveTestDialectOps(MLIRContext *context)
841       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
842 
843   LogicalResult matchAndRewrite(Operation *op,
844                                 PatternRewriter &rewriter) const override {
845     if (!isa<TestDialect>(op->getDialect()))
846       return failure();
847     rewriter.eraseOp(op);
848     return success();
849   }
850 };
851 
852 struct TestUnknownRootOpDriver
853     : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
854   StringRef getArgument() const final {
855     return "test-legalize-unknown-root-patterns";
856   }
857   StringRef getDescription() const final {
858     return "Test public remapped value mechanism in ConversionPatternRewriter";
859   }
860   void runOnFunction() override {
861     mlir::RewritePatternSet patterns(&getContext());
862     patterns.add<RemoveTestDialectOps>(&getContext());
863 
864     mlir::ConversionTarget target(getContext());
865     target.addIllegalDialect<TestDialect>();
866     if (failed(
867             applyPartialConversion(getFunction(), target, std::move(patterns))))
868       signalPassFailure();
869   }
870 };
871 } // end anonymous namespace
872 
873 //===----------------------------------------------------------------------===//
874 // Test type conversions
875 //===----------------------------------------------------------------------===//
876 
877 namespace {
878 struct TestTypeConversionProducer
879     : public OpConversionPattern<TestTypeProducerOp> {
880   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
881   LogicalResult
882   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
883                   ConversionPatternRewriter &rewriter) const final {
884     Type resultType = op.getType();
885     if (resultType.isa<FloatType>())
886       resultType = rewriter.getF64Type();
887     else if (resultType.isInteger(16))
888       resultType = rewriter.getIntegerType(64);
889     else
890       return failure();
891 
892     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
893     return success();
894   }
895 };
896 
897 /// Call signature conversion and then fail the rewrite to trigger the undo
898 /// mechanism.
899 struct TestSignatureConversionUndo
900     : public OpConversionPattern<TestSignatureConversionUndoOp> {
901   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
902 
903   LogicalResult
904   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
905                   ConversionPatternRewriter &rewriter) const final {
906     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
907     return failure();
908   }
909 };
910 
911 /// Just forward the operands to the root op. This is essentially a no-op
912 /// pattern that is used to trigger target materialization.
913 struct TestTypeConsumerForward
914     : public OpConversionPattern<TestTypeConsumerOp> {
915   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
916 
917   LogicalResult
918   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
919                   ConversionPatternRewriter &rewriter) const final {
920     rewriter.updateRootInPlace(op,
921                                [&] { op->setOperands(adaptor.getOperands()); });
922     return success();
923   }
924 };
925 
926 struct TestTypeConversionAnotherProducer
927     : public OpRewritePattern<TestAnotherTypeProducerOp> {
928   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
929 
930   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
931                                 PatternRewriter &rewriter) const final {
932     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
933     return success();
934   }
935 };
936 
937 struct TestTypeConversionDriver
938     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
939   void getDependentDialects(DialectRegistry &registry) const override {
940     registry.insert<TestDialect>();
941   }
942   StringRef getArgument() const final {
943     return "test-legalize-type-conversion";
944   }
945   StringRef getDescription() const final {
946     return "Test various type conversion functionalities in DialectConversion";
947   }
948 
949   void runOnOperation() override {
950     // Initialize the type converter.
951     TypeConverter converter;
952 
953     /// Add the legal set of type conversions.
954     converter.addConversion([](Type type) -> Type {
955       // Treat F64 as legal.
956       if (type.isF64())
957         return type;
958       // Allow converting BF16/F16/F32 to F64.
959       if (type.isBF16() || type.isF16() || type.isF32())
960         return FloatType::getF64(type.getContext());
961       // Otherwise, the type is illegal.
962       return nullptr;
963     });
964     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
965       // Drop all integer types.
966       return success();
967     });
968 
969     /// Add the legal set of type materializations.
970     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
971                                           ValueRange inputs,
972                                           Location loc) -> Value {
973       // Allow casting from F64 back to F32.
974       if (!resultType.isF16() && inputs.size() == 1 &&
975           inputs[0].getType().isF64())
976         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
977       // Allow producing an i32 or i64 from nothing.
978       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
979           inputs.empty())
980         return builder.create<TestTypeProducerOp>(loc, resultType);
981       // Allow producing an i64 from an integer.
982       if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
983           inputs[0].getType().isa<IntegerType>())
984         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
985       // Otherwise, fail.
986       return nullptr;
987     });
988 
989     // Initialize the conversion target.
990     mlir::ConversionTarget target(getContext());
991     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
992       return op.getType().isF64() || op.getType().isInteger(64);
993     });
994     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
995       return converter.isSignatureLegal(op.getType()) &&
996              converter.isLegal(&op.getBody());
997     });
998     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
999       // Allow casts from F64 to F32.
1000       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1001     });
1002 
1003     // Initialize the set of rewrite patterns.
1004     RewritePatternSet patterns(&getContext());
1005     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1006                  TestSignatureConversionUndo>(converter, &getContext());
1007     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
1008     mlir::populateFuncOpTypeConversionPattern(patterns, converter);
1009 
1010     if (failed(applyPartialConversion(getOperation(), target,
1011                                       std::move(patterns))))
1012       signalPassFailure();
1013   }
1014 };
1015 } // end anonymous namespace
1016 
1017 //===----------------------------------------------------------------------===//
1018 // Test Block Merging
1019 //===----------------------------------------------------------------------===//
1020 
1021 namespace {
1022 /// A rewriter pattern that tests that blocks can be merged.
1023 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1024   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1025 
1026   LogicalResult
1027   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1028                   ConversionPatternRewriter &rewriter) const final {
1029     Block &firstBlock = op.body().front();
1030     Operation *branchOp = firstBlock.getTerminator();
1031     Block *secondBlock = &*(std::next(op.body().begin()));
1032     auto succOperands = branchOp->getOperands();
1033     SmallVector<Value, 2> replacements(succOperands);
1034     rewriter.eraseOp(branchOp);
1035     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1036     rewriter.updateRootInPlace(op, [] {});
1037     return success();
1038   }
1039 };
1040 
1041 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1042 struct TestUndoBlocksMerge : public ConversionPattern {
1043   TestUndoBlocksMerge(MLIRContext *ctx)
1044       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1045   LogicalResult
1046   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1047                   ConversionPatternRewriter &rewriter) const final {
1048     Block &firstBlock = op->getRegion(0).front();
1049     Operation *branchOp = firstBlock.getTerminator();
1050     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1051     rewriter.setInsertionPointToStart(secondBlock);
1052     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1053     auto succOperands = branchOp->getOperands();
1054     SmallVector<Value, 2> replacements(succOperands);
1055     rewriter.eraseOp(branchOp);
1056     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1057     rewriter.updateRootInPlace(op, [] {});
1058     return success();
1059   }
1060 };
1061 
1062 /// A rewrite mechanism to inline the body of the op into its parent, when both
1063 /// ops can have a single block.
1064 struct TestMergeSingleBlockOps
1065     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1066   using OpConversionPattern<
1067       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1068 
1069   LogicalResult
1070   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1071                   ConversionPatternRewriter &rewriter) const final {
1072     SingleBlockImplicitTerminatorOp parentOp =
1073         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1074     if (!parentOp)
1075       return failure();
1076     Block &innerBlock = op.region().front();
1077     TerminatorOp innerTerminator =
1078         cast<TerminatorOp>(innerBlock.getTerminator());
1079     rewriter.mergeBlockBefore(&innerBlock, op);
1080     rewriter.eraseOp(innerTerminator);
1081     rewriter.eraseOp(op);
1082     rewriter.updateRootInPlace(op, [] {});
1083     return success();
1084   }
1085 };
1086 
1087 struct TestMergeBlocksPatternDriver
1088     : public PassWrapper<TestMergeBlocksPatternDriver,
1089                          OperationPass<ModuleOp>> {
1090   StringRef getArgument() const final { return "test-merge-blocks"; }
1091   StringRef getDescription() const final {
1092     return "Test Merging operation in ConversionPatternRewriter";
1093   }
1094   void runOnOperation() override {
1095     MLIRContext *context = &getContext();
1096     mlir::RewritePatternSet patterns(context);
1097     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1098         context);
1099     ConversionTarget target(*context);
1100     target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1101                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1102     target.addIllegalOp<ILLegalOpF>();
1103 
1104     /// Expect the op to have a single block after legalization.
1105     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1106         [&](TestMergeBlocksOp op) -> bool {
1107           return llvm::hasSingleElement(op.body());
1108         });
1109 
1110     /// Only allow `test.br` within test.merge_blocks op.
1111     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1112       return op->getParentOfType<TestMergeBlocksOp>();
1113     });
1114 
1115     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1116     /// inlined.
1117     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1118         [&](SingleBlockImplicitTerminatorOp op) -> bool {
1119           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1120         });
1121 
1122     DenseSet<Operation *> unlegalizedOps;
1123     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1124                                  &unlegalizedOps);
1125     for (auto *op : unlegalizedOps)
1126       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1127   }
1128 };
1129 } // namespace
1130 
1131 //===----------------------------------------------------------------------===//
1132 // Test Selective Replacement
1133 //===----------------------------------------------------------------------===//
1134 
1135 namespace {
1136 /// A rewrite mechanism to inline the body of the op into its parent, when both
1137 /// ops can have a single block.
1138 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1139   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1140 
1141   LogicalResult matchAndRewrite(TestCastOp op,
1142                                 PatternRewriter &rewriter) const final {
1143     if (op.getNumOperands() != 2)
1144       return failure();
1145     OperandRange operands = op.getOperands();
1146 
1147     // Replace non-terminator uses with the first operand.
1148     rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1149       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1150     });
1151     // Replace everything else with the second operand if the operation isn't
1152     // dead.
1153     rewriter.replaceOp(op, op.getOperand(1));
1154     return success();
1155   }
1156 };
1157 
1158 struct TestSelectiveReplacementPatternDriver
1159     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1160                          OperationPass<>> {
1161   StringRef getArgument() const final {
1162     return "test-pattern-selective-replacement";
1163   }
1164   StringRef getDescription() const final {
1165     return "Test selective replacement in the PatternRewriter";
1166   }
1167   void runOnOperation() override {
1168     MLIRContext *context = &getContext();
1169     mlir::RewritePatternSet patterns(context);
1170     patterns.add<TestSelectiveOpReplacementPattern>(context);
1171     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1172                                        std::move(patterns));
1173   }
1174 };
1175 } // namespace
1176 
1177 //===----------------------------------------------------------------------===//
1178 // PassRegistration
1179 //===----------------------------------------------------------------------===//
1180 
1181 namespace mlir {
1182 namespace test {
/// Registers the pattern and dialect-conversion test passes listed below with
/// the global pass registry, making them available by their pass arguments.
void registerPatternsTestPass() {
  PassRegistration<TestReturnTypeDriver>();

  PassRegistration<TestDerivedAttributeDriver>();

  PassRegistration<TestPatternDriver>();

  // Custom constructor so each new TestLegalizePatternDriver receives the
  // current value of the `-test-legalize-mode` option.
  PassRegistration<TestLegalizePatternDriver>([] {
    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
  });

  PassRegistration<TestRemappedValue>();

  PassRegistration<TestUnknownRootOpDriver>();

  PassRegistration<TestTypeConversionDriver>();

  PassRegistration<TestMergeBlocksPatternDriver>();
  PassRegistration<TestSelectiveReplacementPatternDriver>();
}
1203 } // namespace test
1204 } // namespace mlir
1205