1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 using namespace mlir;
19 using namespace mlir::test;
20 
21 // Native function for testing NativeCodeCall
22 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
23   return choice.getValue() ? input1 : input2;
24 }
25 
26 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
27   rewriter.create<OpI>(loc, input);
28 }
29 
30 static void handleNoResultOp(PatternRewriter &rewriter,
31                              OpSymbolBindingNoResult op) {
32   // Turn the no result op to a one-result op.
33   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
34                                     op.operand());
35 }
36 
37 // Test that natives calls are only called once during rewrites.
38 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
39 // This let us check the number of times OpM_Test was called by inspecting
40 // the returned value in the MLIR output.
41 static int64_t opMIncreasingValue = 314159265;
42 static Attribute OpMTest(PatternRewriter &rewriter, Value val) {
43   int64_t i = opMIncreasingValue++;
44   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
45 }
46 
47 namespace {
48 #include "TestPatterns.inc"
49 } // end anonymous namespace
50 
51 //===----------------------------------------------------------------------===//
52 // Canonicalizer Driver.
53 //===----------------------------------------------------------------------===//
54 
55 namespace {
56 struct FoldingPattern : public RewritePattern {
57 public:
58   FoldingPattern(MLIRContext *context)
59       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
60                        /*benefit=*/1, context) {}
61 
62   LogicalResult matchAndRewrite(Operation *op,
63                                 PatternRewriter &rewriter) const override {
64     // Exercice OperationFolder API for a single-result operation that is folded
65     // upon construction. The operation being created through the folder has an
66     // in-place folder, and it should be still present in the output.
67     // Furthermore, the folder should not crash when attempting to recover the
68     // (unchanged) operation result.
69     OperationFolder folder(op->getContext());
70     Value result = folder.create<TestOpInPlaceFold>(
71         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
72         rewriter.getI32IntegerAttr(0));
73     assert(result);
74     rewriter.replaceOp(op, result);
75     return success();
76   }
77 };
78 
79 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
80   void runOnFunction() override {
81     mlir::OwningRewritePatternList patterns;
82     populateWithGenerated(&getContext(), patterns);
83 
84     // Verify named pattern is generated with expected name.
85     patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
86 
87     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
88   }
89 };
90 } // end anonymous namespace
91 
92 //===----------------------------------------------------------------------===//
93 // ReturnType Driver.
94 //===----------------------------------------------------------------------===//
95 
96 namespace {
97 // Generate ops for each instance where the type can be successfully inferred.
98 template <typename OpTy>
99 static void invokeCreateWithInferredReturnType(Operation *op) {
100   auto *context = op->getContext();
101   auto fop = op->getParentOfType<FuncOp>();
102   auto location = UnknownLoc::get(context);
103   OpBuilder b(op);
104   b.setInsertionPointAfter(op);
105 
106   // Use permutations of 2 args as operands.
107   assert(fop.getNumArguments() >= 2);
108   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
109     for (int j = 0; j < e; ++j) {
110       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
111       SmallVector<Type, 2> inferredReturnTypes;
112       if (succeeded(OpTy::inferReturnTypes(
113               context, llvm::None, values, op->getAttrDictionary(),
114               op->getRegions(), inferredReturnTypes))) {
115         OperationState state(location, OpTy::getOperationName());
116         // TODO: Expand to regions.
117         OpTy::build(b, state, values, op->getAttrs());
118         (void)b.createOperation(state);
119       }
120     }
121   }
122 }
123 
124 static void reifyReturnShape(Operation *op) {
125   OpBuilder b(op);
126 
127   // Use permutations of 2 args as operands.
128   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
129   SmallVector<Value, 2> shapes;
130   if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
131     return;
132   for (auto it : llvm::enumerate(shapes))
133     op->emitRemark() << "value " << it.index() << ": "
134                      << it.value().getDefiningOp();
135 }
136 
137 struct TestReturnTypeDriver
138     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
139   void runOnFunction() override {
140     if (getFunction().getName() == "testCreateFunctions") {
141       std::vector<Operation *> ops;
142       // Collect ops to avoid triggering on inserted ops.
143       for (auto &op : getFunction().getBody().front())
144         ops.push_back(&op);
145       // Generate test patterns for each, but skip terminator.
146       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
147         // Test create method of each of the Op classes below. The resultant
148         // output would be in reverse order underneath `op` from which
149         // the attributes and regions are used.
150         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
151         invokeCreateWithInferredReturnType<
152             OpWithShapedTypeInferTypeInterfaceOp>(op);
153       };
154       return;
155     }
156     if (getFunction().getName() == "testReifyFunctions") {
157       std::vector<Operation *> ops;
158       // Collect ops to avoid triggering on inserted ops.
159       for (auto &op : getFunction().getBody().front())
160         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
161           ops.push_back(&op);
162       // Generate test patterns for each, but skip terminator.
163       for (auto *op : ops)
164         reifyReturnShape(op);
165     }
166   }
167 };
168 } // end anonymous namespace
169 
170 namespace {
171 struct TestDerivedAttributeDriver
172     : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
173   void runOnFunction() override;
174 };
175 } // end anonymous namespace
176 
177 void TestDerivedAttributeDriver::runOnFunction() {
178   getFunction().walk([](DerivedAttributeOpInterface dOp) {
179     auto dAttr = dOp.materializeDerivedAttributes();
180     if (!dAttr)
181       return;
182     for (auto d : dAttr)
183       dOp.emitRemark() << d.first << " = " << d.second;
184   });
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // Legalization Driver.
189 //===----------------------------------------------------------------------===//
190 
191 namespace {
192 //===----------------------------------------------------------------------===//
193 // Region-Block Rewrite Testing
194 
195 /// This pattern is a simple pattern that inlines the first region of a given
196 /// operation into the parent region.
197 struct TestRegionRewriteBlockMovement : public ConversionPattern {
198   TestRegionRewriteBlockMovement(MLIRContext *ctx)
199       : ConversionPattern("test.region", 1, ctx) {}
200 
201   LogicalResult
202   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
203                   ConversionPatternRewriter &rewriter) const final {
204     // Inline this region into the parent region.
205     auto &parentRegion = *op->getParentRegion();
206     if (op->getAttr("legalizer.should_clone"))
207       rewriter.cloneRegionBefore(op->getRegion(0), parentRegion,
208                                  parentRegion.end());
209     else
210       rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
211                                   parentRegion.end());
212 
213     // Drop this operation.
214     rewriter.eraseOp(op);
215     return success();
216   }
217 };
218 /// This pattern is a simple pattern that generates a region containing an
219 /// illegal operation.
220 struct TestRegionRewriteUndo : public RewritePattern {
221   TestRegionRewriteUndo(MLIRContext *ctx)
222       : RewritePattern("test.region_builder", 1, ctx) {}
223 
224   LogicalResult matchAndRewrite(Operation *op,
225                                 PatternRewriter &rewriter) const final {
226     // Create the region operation with an entry block containing arguments.
227     OperationState newRegion(op->getLoc(), "test.region");
228     newRegion.addRegion();
229     auto *regionOp = rewriter.createOperation(newRegion);
230     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
231     entryBlock->addArgument(rewriter.getIntegerType(64));
232 
233     // Add an explicitly illegal operation to ensure the conversion fails.
234     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
235     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
236 
237     // Drop this operation.
238     rewriter.eraseOp(op);
239     return success();
240   }
241 };
242 /// A simple pattern that creates a block at the end of the parent region of the
243 /// matched operation.
244 struct TestCreateBlock : public RewritePattern {
245   TestCreateBlock(MLIRContext *ctx)
246       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
247 
248   LogicalResult matchAndRewrite(Operation *op,
249                                 PatternRewriter &rewriter) const final {
250     Region &region = *op->getParentRegion();
251     Type i32Type = rewriter.getIntegerType(32);
252     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
253     rewriter.create<TerminatorOp>(op->getLoc());
254     rewriter.replaceOp(op, {});
255     return success();
256   }
257 };
258 
259 /// A simple pattern that creates a block containing an invalid operation in
260 /// order to trigger the block creation undo mechanism.
261 struct TestCreateIllegalBlock : public RewritePattern {
262   TestCreateIllegalBlock(MLIRContext *ctx)
263       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
264 
265   LogicalResult matchAndRewrite(Operation *op,
266                                 PatternRewriter &rewriter) const final {
267     Region &region = *op->getParentRegion();
268     Type i32Type = rewriter.getIntegerType(32);
269     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
270     // Create an illegal op to ensure the conversion fails.
271     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
272     rewriter.create<TerminatorOp>(op->getLoc());
273     rewriter.replaceOp(op, {});
274     return success();
275   }
276 };
277 
278 /// A simple pattern that tests the undo mechanism when replacing the uses of a
279 /// block argument.
280 struct TestUndoBlockArgReplace : public ConversionPattern {
281   TestUndoBlockArgReplace(MLIRContext *ctx)
282       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
283 
284   LogicalResult
285   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
286                   ConversionPatternRewriter &rewriter) const final {
287     auto illegalOp =
288         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
289     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
290                                         illegalOp);
291     rewriter.updateRootInPlace(op, [] {});
292     return success();
293   }
294 };
295 
296 /// A rewrite pattern that tests the undo mechanism when erasing a block.
297 struct TestUndoBlockErase : public ConversionPattern {
298   TestUndoBlockErase(MLIRContext *ctx)
299       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
300 
301   LogicalResult
302   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
303                   ConversionPatternRewriter &rewriter) const final {
304     Block *secondBlock = &*std::next(op->getRegion(0).begin());
305     rewriter.setInsertionPointToStart(secondBlock);
306     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
307     rewriter.eraseBlock(secondBlock);
308     rewriter.updateRootInPlace(op, [] {});
309     return success();
310   }
311 };
312 
313 //===----------------------------------------------------------------------===//
314 // Type-Conversion Rewrite Testing
315 
316 /// This patterns erases a region operation that has had a type conversion.
317 struct TestDropOpSignatureConversion : public ConversionPattern {
318   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
319       : ConversionPattern("test.drop_region_op", 1, converter, ctx) {}
320   LogicalResult
321   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
322                   ConversionPatternRewriter &rewriter) const override {
323     Region &region = op->getRegion(0);
324     Block *entry = &region.front();
325 
326     // Convert the original entry arguments.
327     TypeConverter &converter = *getTypeConverter();
328     TypeConverter::SignatureConversion result(entry->getNumArguments());
329     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
330                                               result)) ||
331         failed(rewriter.convertRegionTypes(&region, converter, &result)))
332       return failure();
333 
334     // Convert the region signature and just drop the operation.
335     rewriter.eraseOp(op);
336     return success();
337   }
338 };
339 /// This pattern simply updates the operands of the given operation.
340 struct TestPassthroughInvalidOp : public ConversionPattern {
341   TestPassthroughInvalidOp(MLIRContext *ctx)
342       : ConversionPattern("test.invalid", 1, ctx) {}
343   LogicalResult
344   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
345                   ConversionPatternRewriter &rewriter) const final {
346     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
347                                              llvm::None);
348     return success();
349   }
350 };
351 /// This pattern handles the case of a split return value.
352 struct TestSplitReturnType : public ConversionPattern {
353   TestSplitReturnType(MLIRContext *ctx)
354       : ConversionPattern("test.return", 1, ctx) {}
355   LogicalResult
356   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
357                   ConversionPatternRewriter &rewriter) const final {
358     // Check for a return of F32.
359     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
360       return failure();
361 
362     // Check if the first operation is a cast operation, if it is we use the
363     // results directly.
364     auto *defOp = operands[0].getDefiningOp();
365     if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
366       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
367       return success();
368     }
369 
370     // Otherwise, fail to match.
371     return failure();
372   }
373 };
374 
375 //===----------------------------------------------------------------------===//
376 // Multi-Level Type-Conversion Rewrite Testing
377 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
378   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
379       : ConversionPattern("test.type_producer", 1, ctx) {}
380   LogicalResult
381   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
382                   ConversionPatternRewriter &rewriter) const final {
383     // If the type is I32, change the type to F32.
384     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
385       return failure();
386     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
387     return success();
388   }
389 };
390 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
391   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
392       : ConversionPattern("test.type_producer", 1, ctx) {}
393   LogicalResult
394   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
395                   ConversionPatternRewriter &rewriter) const final {
396     // If the type is F32, change the type to F64.
397     if (!Type(*op->result_type_begin()).isF32())
398       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
399     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
400     return success();
401   }
402 };
403 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
404   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
405       : ConversionPattern("test.type_producer", 10, ctx) {}
406   LogicalResult
407   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
408                   ConversionPatternRewriter &rewriter) const final {
409     // Always convert to B16, even though it is not a legal type. This tests
410     // that values are unmapped correctly.
411     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
412     return success();
413   }
414 };
415 struct TestUpdateConsumerType : public ConversionPattern {
416   TestUpdateConsumerType(MLIRContext *ctx)
417       : ConversionPattern("test.type_consumer", 1, ctx) {}
418   LogicalResult
419   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
420                   ConversionPatternRewriter &rewriter) const final {
421     // Verify that the incoming operand has been successfully remapped to F64.
422     if (!operands[0].getType().isF64())
423       return failure();
424     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
425     return success();
426   }
427 };
428 
429 //===----------------------------------------------------------------------===//
430 // Non-Root Replacement Rewrite Testing
431 /// This pattern generates an invalid operation, but replaces it before the
432 /// pattern is finished. This checks that we don't need to legalize the
433 /// temporary op.
434 struct TestNonRootReplacement : public RewritePattern {
435   TestNonRootReplacement(MLIRContext *ctx)
436       : RewritePattern("test.replace_non_root", 1, ctx) {}
437 
438   LogicalResult matchAndRewrite(Operation *op,
439                                 PatternRewriter &rewriter) const final {
440     auto resultType = *op->result_type_begin();
441     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
442     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
443 
444     rewriter.replaceOp(illegalOp, {legalOp});
445     rewriter.replaceOp(op, {illegalOp});
446     return success();
447   }
448 };
449 
450 //===----------------------------------------------------------------------===//
451 // Recursive Rewrite Testing
452 /// This pattern is applied to the same operation multiple times, but has a
453 /// bounded recursion.
454 struct TestBoundedRecursiveRewrite
455     : public OpRewritePattern<TestRecursiveRewriteOp> {
456   TestBoundedRecursiveRewrite(MLIRContext *ctx)
457       : OpRewritePattern<TestRecursiveRewriteOp>(ctx) {
458     // The conversion target handles bounding the recursion of this pattern.
459     setHasBoundedRewriteRecursion();
460   }
461 
462   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
463                                 PatternRewriter &rewriter) const final {
464     // Decrement the depth of the op in-place.
465     rewriter.updateRootInPlace(op, [&] {
466       op.setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1));
467     });
468     return success();
469   }
470 };
471 
472 struct TestNestedOpCreationUndoRewrite
473     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
474   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
475 
476   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
477                                 PatternRewriter &rewriter) const final {
478     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
479     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
480     return success();
481   };
482 };
483 } // namespace
484 
485 namespace {
486 struct TestTypeConverter : public TypeConverter {
487   using TypeConverter::TypeConverter;
488   TestTypeConverter() {
489     addConversion(convertType);
490     addArgumentMaterialization(materializeCast);
491     addArgumentMaterialization(materializeOneToOneCast);
492     addSourceMaterialization(materializeCast);
493   }
494 
495   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
496     // Drop I16 types.
497     if (t.isSignlessInteger(16))
498       return success();
499 
500     // Convert I64 to F64.
501     if (t.isSignlessInteger(64)) {
502       results.push_back(FloatType::getF64(t.getContext()));
503       return success();
504     }
505 
506     // Convert I42 to I43.
507     if (t.isInteger(42)) {
508       results.push_back(IntegerType::get(43, t.getContext()));
509       return success();
510     }
511 
512     // Split F32 into F16,F16.
513     if (t.isF32()) {
514       results.assign(2, FloatType::getF16(t.getContext()));
515       return success();
516     }
517 
518     // Otherwise, convert the type directly.
519     results.push_back(t);
520     return success();
521   }
522 
523   /// Hook for materializing a conversion. This is necessary because we generate
524   /// 1->N type mappings.
525   static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
526                                          ValueRange inputs, Location loc) {
527     if (inputs.size() == 1)
528       return inputs[0];
529     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
530   }
531 
532   /// Materialize the cast for one-to-one conversion from i64 to f64.
533   static Optional<Value> materializeOneToOneCast(OpBuilder &builder,
534                                                  IntegerType resultType,
535                                                  ValueRange inputs,
536                                                  Location loc) {
537     if (resultType.getWidth() == 42 && inputs.size() == 1)
538       return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
539     return llvm::None;
540   }
541 };
542 
543 struct TestLegalizePatternDriver
544     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
545   /// The mode of conversion to use with the driver.
546   enum class ConversionMode { Analysis, Full, Partial };
547 
548   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
549 
550   void runOnOperation() override {
551     TestTypeConverter converter;
552     mlir::OwningRewritePatternList patterns;
553     populateWithGenerated(&getContext(), patterns);
554     patterns.insert<
555         TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
556         TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
557         TestPassthroughInvalidOp, TestSplitReturnType,
558         TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
559         TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
560         TestNonRootReplacement, TestBoundedRecursiveRewrite,
561         TestNestedOpCreationUndoRewrite>(&getContext());
562     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
563     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
564                                               converter);
565     mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
566                                               converter);
567 
568     // Define the conversion target used for the test.
569     ConversionTarget target(getContext());
570     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
571     target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
572                       TerminatorOp>();
573     target
574         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
575     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
576       // Don't allow F32 operands.
577       return llvm::none_of(op.getOperandTypes(),
578                            [](Type type) { return type.isF32(); });
579     });
580     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
581       return converter.isSignatureLegal(op.getType()) &&
582              converter.isLegal(&op.getBody());
583     });
584 
585     // Expect the type_producer/type_consumer operations to only operate on f64.
586     target.addDynamicallyLegalOp<TestTypeProducerOp>(
587         [](TestTypeProducerOp op) { return op.getType().isF64(); });
588     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
589       return op.getOperand().getType().isF64();
590     });
591 
592     // Check support for marking certain operations as recursively legal.
593     target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
594       return static_cast<bool>(
595           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
596     });
597 
598     // Mark the bound recursion operation as dynamically legal.
599     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
600         [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
601 
602     // Handle a partial conversion.
603     if (mode == ConversionMode::Partial) {
604       DenseSet<Operation *> unlegalizedOps;
605       (void)applyPartialConversion(getOperation(), target, std::move(patterns),
606                                    &unlegalizedOps);
607       // Emit remarks for each legalizable operation.
608       for (auto *op : unlegalizedOps)
609         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
610       return;
611     }
612 
613     // Handle a full conversion.
614     if (mode == ConversionMode::Full) {
615       // Check support for marking unknown operations as dynamically legal.
616       target.markUnknownOpDynamicallyLegal([](Operation *op) {
617         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
618       });
619 
620       (void)applyFullConversion(getOperation(), target, std::move(patterns));
621       return;
622     }
623 
624     // Otherwise, handle an analysis conversion.
625     assert(mode == ConversionMode::Analysis);
626 
627     // Analyze the convertible operations.
628     DenseSet<Operation *> legalizedOps;
629     if (failed(applyAnalysisConversion(getOperation(), target,
630                                        std::move(patterns), legalizedOps)))
631       return signalPassFailure();
632 
633     // Emit remarks for each legalizable operation.
634     for (auto *op : legalizedOps)
635       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
636   }
637 
638   /// The mode of conversion to use.
639   ConversionMode mode;
640 };
641 } // end anonymous namespace
642 
643 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
644     legalizerConversionMode(
645         "test-legalize-mode",
646         llvm::cl::desc("The legalization mode to use with the test driver"),
647         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
648         llvm::cl::values(
649             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
650                        "analysis", "Perform an analysis conversion"),
651             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
652                        "Perform a full conversion"),
653             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
654                        "partial", "Perform a partial conversion")));
655 
656 //===----------------------------------------------------------------------===//
657 // ConversionPatternRewriter::getRemappedValue testing. This method is used
658 // to get the remapped value of an original value that was replaced using
659 // ConversionPatternRewriter.
660 namespace {
661 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
662 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
663 /// operand twice.
664 ///
665 /// Example:
666 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
667 /// is replaced with:
668 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
669 struct OneVResOneVOperandOp1Converter
670     : public OpConversionPattern<OneVResOneVOperandOp1> {
671   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
672 
673   LogicalResult
674   matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
675                   ConversionPatternRewriter &rewriter) const override {
676     auto origOps = op.getOperands();
677     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
678            "One operand expected");
679     Value origOp = *origOps.begin();
680     SmallVector<Value, 2> remappedOperands;
681     // Replicate the remapped original operand twice. Note that we don't used
682     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
683     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
684     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
685 
686     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
687                                                        remappedOperands);
688     return success();
689   }
690 };
691 
692 struct TestRemappedValue
693     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
694   void runOnFunction() override {
695     mlir::OwningRewritePatternList patterns;
696     patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
697 
698     mlir::ConversionTarget target(getContext());
699     target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
700     // We make OneVResOneVOperandOp1 legal only when it has more that one
701     // operand. This will trigger the conversion that will replace one-operand
702     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
703     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
704         [](Operation *op) -> bool {
705           return std::distance(op->operand_begin(), op->operand_end()) > 1;
706         });
707 
708     if (failed(mlir::applyFullConversion(getFunction(), target,
709                                          std::move(patterns)))) {
710       signalPassFailure();
711     }
712   }
713 };
714 } // end anonymous namespace
715 
716 //===----------------------------------------------------------------------===//
717 // Test patterns without a specific root operation kind
718 //===----------------------------------------------------------------------===//
719 
720 namespace {
721 /// This pattern matches and removes any operation in the test dialect.
722 struct RemoveTestDialectOps : public RewritePattern {
723   RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
724 
725   LogicalResult matchAndRewrite(Operation *op,
726                                 PatternRewriter &rewriter) const override {
727     if (!isa<TestDialect>(op->getDialect()))
728       return failure();
729     rewriter.eraseOp(op);
730     return success();
731   }
732 };
733 
734 struct TestUnknownRootOpDriver
735     : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
736   void runOnFunction() override {
737     mlir::OwningRewritePatternList patterns;
738     patterns.insert<RemoveTestDialectOps>();
739 
740     mlir::ConversionTarget target(getContext());
741     target.addIllegalDialect<TestDialect>();
742     if (failed(
743             applyPartialConversion(getFunction(), target, std::move(patterns))))
744       signalPassFailure();
745   }
746 };
747 } // end anonymous namespace
748 
749 //===----------------------------------------------------------------------===//
750 // Test type conversions
751 //===----------------------------------------------------------------------===//
752 
753 namespace {
754 struct TestTypeConversionProducer
755     : public OpConversionPattern<TestTypeProducerOp> {
756   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
757   LogicalResult
758   matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands,
759                   ConversionPatternRewriter &rewriter) const final {
760     Type resultType = op.getType();
761     if (resultType.isa<FloatType>())
762       resultType = rewriter.getF64Type();
763     else if (resultType.isInteger(16))
764       resultType = rewriter.getIntegerType(64);
765     else
766       return failure();
767 
768     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
769     return success();
770   }
771 };
772 
773 struct TestTypeConversionDriver
774     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
775   void getDependentDialects(DialectRegistry &registry) const override {
776     registry.insert<TestDialect>();
777   }
778 
779   void runOnOperation() override {
780     // Initialize the type converter.
781     TypeConverter converter;
782 
783     /// Add the legal set of type conversions.
784     converter.addConversion([](Type type) -> Type {
785       // Treat F64 as legal.
786       if (type.isF64())
787         return type;
788       // Allow converting BF16/F16/F32 to F64.
789       if (type.isBF16() || type.isF16() || type.isF32())
790         return FloatType::getF64(type.getContext());
791       // Otherwise, the type is illegal.
792       return nullptr;
793     });
794     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
795       // Drop all integer types.
796       return success();
797     });
798 
799     /// Add the legal set of type materializations.
800     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
801                                           ValueRange inputs,
802                                           Location loc) -> Value {
803       // Allow casting from F64 back to F32.
804       if (!resultType.isF16() && inputs.size() == 1 &&
805           inputs[0].getType().isF64())
806         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
807       // Allow producing an i32 or i64 from nothing.
808       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
809           inputs.empty())
810         return builder.create<TestTypeProducerOp>(loc, resultType);
811       // Allow producing an i64 from an integer.
812       if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
813           inputs[0].getType().isa<IntegerType>())
814         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
815       // Otherwise, fail.
816       return nullptr;
817     });
818 
819     // Initialize the conversion target.
820     mlir::ConversionTarget target(getContext());
821     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
822       return op.getType().isF64() || op.getType().isInteger(64);
823     });
824     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
825       return converter.isSignatureLegal(op.getType()) &&
826              converter.isLegal(&op.getBody());
827     });
828     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
829       // Allow casts from F64 to F32.
830       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
831     });
832 
833     // Initialize the set of rewrite patterns.
834     OwningRewritePatternList patterns;
835     patterns.insert<TestTypeConversionProducer>(converter, &getContext());
836     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
837                                               converter);
838 
839     if (failed(applyPartialConversion(getOperation(), target,
840                                       std::move(patterns))))
841       signalPassFailure();
842   }
843 };
844 } // end anonymous namespace
845 
846 namespace {
847 /// A rewriter pattern that tests that blocks can be merged.
848 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
849   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
850 
851   LogicalResult
852   matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands,
853                   ConversionPatternRewriter &rewriter) const final {
854     Block &firstBlock = op.body().front();
855     Operation *branchOp = firstBlock.getTerminator();
856     Block *secondBlock = &*(std::next(op.body().begin()));
857     auto succOperands = branchOp->getOperands();
858     SmallVector<Value, 2> replacements(succOperands);
859     rewriter.eraseOp(branchOp);
860     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
861     rewriter.updateRootInPlace(op, [] {});
862     return success();
863   }
864 };
865 
866 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
867 struct TestUndoBlocksMerge : public ConversionPattern {
868   TestUndoBlocksMerge(MLIRContext *ctx)
869       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
870   LogicalResult
871   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
872                   ConversionPatternRewriter &rewriter) const final {
873     Block &firstBlock = op->getRegion(0).front();
874     Operation *branchOp = firstBlock.getTerminator();
875     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
876     rewriter.setInsertionPointToStart(secondBlock);
877     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
878     auto succOperands = branchOp->getOperands();
879     SmallVector<Value, 2> replacements(succOperands);
880     rewriter.eraseOp(branchOp);
881     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
882     rewriter.updateRootInPlace(op, [] {});
883     return success();
884   }
885 };
886 
887 /// A rewrite mechanism to inline the body of the op into its parent, when both
888 /// ops can have a single block.
889 struct TestMergeSingleBlockOps
890     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
891   using OpConversionPattern<
892       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
893 
894   LogicalResult
895   matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands,
896                   ConversionPatternRewriter &rewriter) const final {
897     SingleBlockImplicitTerminatorOp parentOp =
898         op.getParentOfType<SingleBlockImplicitTerminatorOp>();
899     if (!parentOp)
900       return failure();
901     Block &innerBlock = op.region().front();
902     TerminatorOp innerTerminator =
903         cast<TerminatorOp>(innerBlock.getTerminator());
904     rewriter.mergeBlockBefore(&innerBlock, op);
905     rewriter.eraseOp(innerTerminator);
906     rewriter.eraseOp(op);
907     rewriter.updateRootInPlace(op, [] {});
908     return success();
909   }
910 };
911 
912 struct TestMergeBlocksPatternDriver
913     : public PassWrapper<TestMergeBlocksPatternDriver,
914                          OperationPass<ModuleOp>> {
915   void runOnOperation() override {
916     mlir::OwningRewritePatternList patterns;
917     MLIRContext *context = &getContext();
918     patterns
919         .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
920             context);
921     ConversionTarget target(*context);
922     target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp,
923                       TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp,
924                       TestReturnOp>();
925     target.addIllegalOp<ILLegalOpF>();
926 
927     /// Expect the op to have a single block after legalization.
928     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
929         [&](TestMergeBlocksOp op) -> bool {
930           return llvm::hasSingleElement(op.body());
931         });
932 
933     /// Only allow `test.br` within test.merge_blocks op.
934     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
935       return op.getParentOfType<TestMergeBlocksOp>();
936     });
937 
938     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
939     /// inlined.
940     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
941         [&](SingleBlockImplicitTerminatorOp op) -> bool {
942           return !op.getParentOfType<SingleBlockImplicitTerminatorOp>();
943         });
944 
945     DenseSet<Operation *> unlegalizedOps;
946     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
947                                  &unlegalizedOps);
948     for (auto *op : unlegalizedOps)
949       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
950   }
951 };
952 } // namespace
953 
954 //===----------------------------------------------------------------------===//
955 // PassRegistration
956 //===----------------------------------------------------------------------===//
957 
958 namespace mlir {
959 namespace test {
960 void registerPatternsTestPass() {
961   PassRegistration<TestReturnTypeDriver>("test-return-type",
962                                          "Run return type functions");
963 
964   PassRegistration<TestDerivedAttributeDriver>("test-derived-attr",
965                                                "Run test derived attributes");
966 
967   PassRegistration<TestPatternDriver>("test-patterns",
968                                       "Run test dialect patterns");
969 
970   PassRegistration<TestLegalizePatternDriver>(
971       "test-legalize-patterns", "Run test dialect legalization patterns", [] {
972         return std::make_unique<TestLegalizePatternDriver>(
973             legalizerConversionMode);
974       });
975 
976   PassRegistration<TestRemappedValue>(
977       "test-remapped-value",
978       "Test public remapped value mechanism in ConversionPatternRewriter");
979 
980   PassRegistration<TestUnknownRootOpDriver>(
981       "test-legalize-unknown-root-patterns",
982       "Test public remapped value mechanism in ConversionPatternRewriter");
983 
984   PassRegistration<TestTypeConversionDriver>(
985       "test-legalize-type-conversion",
986       "Test various type conversion functionalities in DialectConversion");
987 
988   PassRegistration<TestMergeBlocksPatternDriver>{
989       "test-merge-blocks",
990       "Test Merging operation in ConversionPatternRewriter"};
991 }
992 } // namespace test
993 } // namespace mlir
994