1fec6c5acSUday Bondhugula //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2fec6c5acSUday Bondhugula //
3fec6c5acSUday Bondhugula // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fec6c5acSUday Bondhugula // See https://llvm.org/LICENSE.txt for license information.
5fec6c5acSUday Bondhugula // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fec6c5acSUday Bondhugula //
7fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
8fec6c5acSUday Bondhugula 
9fec6c5acSUday Bondhugula #include "TestDialect.h"
1026f93d9fSAlex Zinenko #include "mlir/Dialect/StandardOps/IR/Ops.h"
11473bdaf2SAlex Zinenko #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
12c0a6318dSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
132bf423b0SRob Suderman #include "mlir/IR/Matchers.h"
14fec6c5acSUday Bondhugula #include "mlir/Pass/Pass.h"
15fec6c5acSUday Bondhugula #include "mlir/Transforms/DialectConversion.h"
1626f93d9fSAlex Zinenko #include "mlir/Transforms/FoldUtils.h"
17b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
189ba37b3bSJacques Pienaar 
19fec6c5acSUday Bondhugula using namespace mlir;
207776b19eSStephen Neuendorffer using namespace test;
21fec6c5acSUday Bondhugula 
22fec6c5acSUday Bondhugula // Native function for testing NativeCodeCall
23fec6c5acSUday Bondhugula static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
24fec6c5acSUday Bondhugula   return choice.getValue() ? input1 : input2;
25fec6c5acSUday Bondhugula }
26fec6c5acSUday Bondhugula 
2729429d1aSJacques Pienaar static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
2829429d1aSJacques Pienaar   rewriter.create<OpI>(loc, input);
29fec6c5acSUday Bondhugula }
30fec6c5acSUday Bondhugula 
31fec6c5acSUday Bondhugula static void handleNoResultOp(PatternRewriter &rewriter,
32fec6c5acSUday Bondhugula                              OpSymbolBindingNoResult op) {
33fec6c5acSUday Bondhugula   // Turn the no result op to a one-result op.
34fec6c5acSUday Bondhugula   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
35fec6c5acSUday Bondhugula                                     op.operand());
36fec6c5acSUday Bondhugula }
37fec6c5acSUday Bondhugula 
3834b5482bSChia-hung Duan static bool getFirstI32Result(Operation *op, Value &value) {
3934b5482bSChia-hung Duan   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
4034b5482bSChia-hung Duan     return false;
4134b5482bSChia-hung Duan   value = op->getResult(0);
4234b5482bSChia-hung Duan   return true;
4334b5482bSChia-hung Duan }
4434b5482bSChia-hung Duan 
4534b5482bSChia-hung Duan static Value bindNativeCodeCallResult(Value value) { return value; }
4634b5482bSChia-hung Duan 
47d7314b3cSChia-hung Duan static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
48d7314b3cSChia-hung Duan                                                               Value input2) {
49d7314b3cSChia-hung Duan   return SmallVector<Value, 2>({input2, input1});
50d7314b3cSChia-hung Duan }
51d7314b3cSChia-hung Duan 
5201641197SAlexEichenberger // Test that natives calls are only called once during rewrites.
5301641197SAlexEichenberger // OpM_Test will return Pi, increased by 1 for each subsequent calls.
5401641197SAlexEichenberger // This let us check the number of times OpM_Test was called by inspecting
5501641197SAlexEichenberger // the returned value in the MLIR output.
5601641197SAlexEichenberger static int64_t opMIncreasingValue = 314159265;
5701641197SAlexEichenberger static Attribute OpMTest(PatternRewriter &rewriter, Value val) {
5801641197SAlexEichenberger   int64_t i = opMIncreasingValue++;
5901641197SAlexEichenberger   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
6001641197SAlexEichenberger }
6101641197SAlexEichenberger 
62fec6c5acSUday Bondhugula namespace {
63fec6c5acSUday Bondhugula #include "TestPatterns.inc"
64fec6c5acSUday Bondhugula } // end anonymous namespace
65fec6c5acSUday Bondhugula 
66fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
67c484c7ddSChia-hung Duan // Test Reduce Pattern Interface
68c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
69c484c7ddSChia-hung Duan 
707776b19eSStephen Neuendorffer void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
71c484c7ddSChia-hung Duan   populateWithGenerated(patterns);
72c484c7ddSChia-hung Duan }
73c484c7ddSChia-hung Duan 
74c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
75fec6c5acSUday Bondhugula // Canonicalizer Driver.
76fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
77fec6c5acSUday Bondhugula 
78fec6c5acSUday Bondhugula namespace {
7926f93d9fSAlex Zinenko struct FoldingPattern : public RewritePattern {
8026f93d9fSAlex Zinenko public:
8126f93d9fSAlex Zinenko   FoldingPattern(MLIRContext *context)
8226f93d9fSAlex Zinenko       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
8326f93d9fSAlex Zinenko                        /*benefit=*/1, context) {}
8426f93d9fSAlex Zinenko 
8526f93d9fSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
8626f93d9fSAlex Zinenko                                 PatternRewriter &rewriter) const override {
872b638ed5SKazuaki Ishizaki     // Exercise OperationFolder API for a single-result operation that is folded
8826f93d9fSAlex Zinenko     // upon construction. The operation being created through the folder has an
8926f93d9fSAlex Zinenko     // in-place folder, and it should be still present in the output.
9026f93d9fSAlex Zinenko     // Furthermore, the folder should not crash when attempting to recover the
91a23d0559SKazuaki Ishizaki     // (unchanged) operation result.
9226f93d9fSAlex Zinenko     OperationFolder folder(op->getContext());
9326f93d9fSAlex Zinenko     Value result = folder.create<TestOpInPlaceFold>(
9426f93d9fSAlex Zinenko         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
9526f93d9fSAlex Zinenko         rewriter.getI32IntegerAttr(0));
9626f93d9fSAlex Zinenko     assert(result);
9726f93d9fSAlex Zinenko     rewriter.replaceOp(op, result);
9826f93d9fSAlex Zinenko     return success();
9926f93d9fSAlex Zinenko   }
10026f93d9fSAlex Zinenko };
10126f93d9fSAlex Zinenko 
102*e4635e63SRiver Riddle /// This pattern creates a foldable operation at the entry point of the block.
103*e4635e63SRiver Riddle /// This tests the situation where the operation folder will need to replace an
104*e4635e63SRiver Riddle /// operation with a previously created constant that does not initially
105*e4635e63SRiver Riddle /// dominate the operation to replace.
106*e4635e63SRiver Riddle struct FolderInsertBeforePreviouslyFoldedConstantPattern
107*e4635e63SRiver Riddle     : public OpRewritePattern<TestCastOp> {
108*e4635e63SRiver Riddle public:
109*e4635e63SRiver Riddle   using OpRewritePattern<TestCastOp>::OpRewritePattern;
110*e4635e63SRiver Riddle 
111*e4635e63SRiver Riddle   LogicalResult matchAndRewrite(TestCastOp op,
112*e4635e63SRiver Riddle                                 PatternRewriter &rewriter) const override {
113*e4635e63SRiver Riddle     if (!op->hasAttr("test_fold_before_previously_folded_op"))
114*e4635e63SRiver Riddle       return failure();
115*e4635e63SRiver Riddle     rewriter.setInsertionPointToStart(op->getBlock());
116*e4635e63SRiver Riddle 
117*e4635e63SRiver Riddle     auto constOp =
118*e4635e63SRiver Riddle         rewriter.create<ConstantOp>(op.getLoc(), rewriter.getBoolAttr(true));
119*e4635e63SRiver Riddle     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
120*e4635e63SRiver Riddle                                             Value(constOp));
121*e4635e63SRiver Riddle     return success();
122*e4635e63SRiver Riddle   }
123*e4635e63SRiver Riddle };
124*e4635e63SRiver Riddle 
12580aca1eaSRiver Riddle struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
126b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-patterns"; }
127b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "Run test dialect patterns"; }
128fec6c5acSUday Bondhugula   void runOnFunction() override {
129dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
1301d909c9aSChris Lattner     populateWithGenerated(patterns);
131fec6c5acSUday Bondhugula 
132fec6c5acSUday Bondhugula     // Verify named pattern is generated with expected name.
133*e4635e63SRiver Riddle     patterns.add<FoldingPattern, TestNamedPatternRule,
134*e4635e63SRiver Riddle                  FolderInsertBeforePreviouslyFoldedConstantPattern>(
135*e4635e63SRiver Riddle         &getContext());
136fec6c5acSUday Bondhugula 
137e21adfa3SRiver Riddle     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
138fec6c5acSUday Bondhugula   }
139fec6c5acSUday Bondhugula };
140fec6c5acSUday Bondhugula } // end anonymous namespace
141fec6c5acSUday Bondhugula 
142fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
143fec6c5acSUday Bondhugula // ReturnType Driver.
144fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
145fec6c5acSUday Bondhugula 
146fec6c5acSUday Bondhugula namespace {
147fec6c5acSUday Bondhugula // Generate ops for each instance where the type can be successfully inferred.
148fec6c5acSUday Bondhugula template <typename OpTy>
149fec6c5acSUday Bondhugula static void invokeCreateWithInferredReturnType(Operation *op) {
150fec6c5acSUday Bondhugula   auto *context = op->getContext();
151fec6c5acSUday Bondhugula   auto fop = op->getParentOfType<FuncOp>();
152fec6c5acSUday Bondhugula   auto location = UnknownLoc::get(context);
153fec6c5acSUday Bondhugula   OpBuilder b(op);
154fec6c5acSUday Bondhugula   b.setInsertionPointAfter(op);
155fec6c5acSUday Bondhugula 
156fec6c5acSUday Bondhugula   // Use permutations of 2 args as operands.
157fec6c5acSUday Bondhugula   assert(fop.getNumArguments() >= 2);
158fec6c5acSUday Bondhugula   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
159fec6c5acSUday Bondhugula     for (int j = 0; j < e; ++j) {
160fec6c5acSUday Bondhugula       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
161fec6c5acSUday Bondhugula       SmallVector<Type, 2> inferredReturnTypes;
1625eae715aSJacques Pienaar       if (succeeded(OpTy::inferReturnTypes(
1635eae715aSJacques Pienaar               context, llvm::None, values, op->getAttrDictionary(),
1645eae715aSJacques Pienaar               op->getRegions(), inferredReturnTypes))) {
165fec6c5acSUday Bondhugula         OperationState state(location, OpTy::getOperationName());
1669db53a18SRiver Riddle         // TODO: Expand to regions.
167bb1d976fSAlex Zinenko         OpTy::build(b, state, values, op->getAttrs());
168fec6c5acSUday Bondhugula         (void)b.createOperation(state);
169fec6c5acSUday Bondhugula       }
170fec6c5acSUday Bondhugula     }
171fec6c5acSUday Bondhugula   }
172fec6c5acSUday Bondhugula }
173fec6c5acSUday Bondhugula 
174fec6c5acSUday Bondhugula static void reifyReturnShape(Operation *op) {
175fec6c5acSUday Bondhugula   OpBuilder b(op);
176fec6c5acSUday Bondhugula 
177fec6c5acSUday Bondhugula   // Use permutations of 2 args as operands.
178fec6c5acSUday Bondhugula   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
179fec6c5acSUday Bondhugula   SmallVector<Value, 2> shapes;
180851d02f6SWenyi Zhao   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
1819b051703SMaheshRavishankar       !llvm::hasSingleElement(shapes))
182fec6c5acSUday Bondhugula     return;
1839b051703SMaheshRavishankar   for (auto it : llvm::enumerate(shapes)) {
184fec6c5acSUday Bondhugula     op->emitRemark() << "value " << it.index() << ": "
185fec6c5acSUday Bondhugula                      << it.value().getDefiningOp();
186fec6c5acSUday Bondhugula   }
1879b051703SMaheshRavishankar }
188fec6c5acSUday Bondhugula 
18980aca1eaSRiver Riddle struct TestReturnTypeDriver
19080aca1eaSRiver Riddle     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
191e2310704SJulian Gross   void getDependentDialects(DialectRegistry &registry) const override {
192c0a6318dSMatthias Springer     registry.insert<tensor::TensorDialect>();
193e2310704SJulian Gross   }
194b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-return-type"; }
195b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "Run return type functions"; }
196e2310704SJulian Gross 
197fec6c5acSUday Bondhugula   void runOnFunction() override {
198fec6c5acSUday Bondhugula     if (getFunction().getName() == "testCreateFunctions") {
199fec6c5acSUday Bondhugula       std::vector<Operation *> ops;
200fec6c5acSUday Bondhugula       // Collect ops to avoid triggering on inserted ops.
201fec6c5acSUday Bondhugula       for (auto &op : getFunction().getBody().front())
202fec6c5acSUday Bondhugula         ops.push_back(&op);
203fec6c5acSUday Bondhugula       // Generate test patterns for each, but skip terminator.
204fec6c5acSUday Bondhugula       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
205fec6c5acSUday Bondhugula         // Test create method of each of the Op classes below. The resultant
206fec6c5acSUday Bondhugula         // output would be in reverse order underneath `op` from which
207fec6c5acSUday Bondhugula         // the attributes and regions are used.
208fec6c5acSUday Bondhugula         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
209fec6c5acSUday Bondhugula         invokeCreateWithInferredReturnType<
210fec6c5acSUday Bondhugula             OpWithShapedTypeInferTypeInterfaceOp>(op);
211fec6c5acSUday Bondhugula       };
212fec6c5acSUday Bondhugula       return;
213fec6c5acSUday Bondhugula     }
214fec6c5acSUday Bondhugula     if (getFunction().getName() == "testReifyFunctions") {
215fec6c5acSUday Bondhugula       std::vector<Operation *> ops;
216fec6c5acSUday Bondhugula       // Collect ops to avoid triggering on inserted ops.
217fec6c5acSUday Bondhugula       for (auto &op : getFunction().getBody().front())
218fec6c5acSUday Bondhugula         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
219fec6c5acSUday Bondhugula           ops.push_back(&op);
220fec6c5acSUday Bondhugula       // Generate test patterns for each, but skip terminator.
221fec6c5acSUday Bondhugula       for (auto *op : ops)
222fec6c5acSUday Bondhugula         reifyReturnShape(op);
223fec6c5acSUday Bondhugula     }
224fec6c5acSUday Bondhugula   }
225fec6c5acSUday Bondhugula };
226fec6c5acSUday Bondhugula } // end anonymous namespace
227fec6c5acSUday Bondhugula 
2289ba37b3bSJacques Pienaar namespace {
2299ba37b3bSJacques Pienaar struct TestDerivedAttributeDriver
2309ba37b3bSJacques Pienaar     : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
231b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-derived-attr"; }
232b5e22e6dSMehdi Amini   StringRef getDescription() const final {
233b5e22e6dSMehdi Amini     return "Run test derived attributes";
234b5e22e6dSMehdi Amini   }
2359ba37b3bSJacques Pienaar   void runOnFunction() override;
2369ba37b3bSJacques Pienaar };
2379ba37b3bSJacques Pienaar } // end anonymous namespace
2389ba37b3bSJacques Pienaar 
2399ba37b3bSJacques Pienaar void TestDerivedAttributeDriver::runOnFunction() {
2409ba37b3bSJacques Pienaar   getFunction().walk([](DerivedAttributeOpInterface dOp) {
2419ba37b3bSJacques Pienaar     auto dAttr = dOp.materializeDerivedAttributes();
2429ba37b3bSJacques Pienaar     if (!dAttr)
2439ba37b3bSJacques Pienaar       return;
2449ba37b3bSJacques Pienaar     for (auto d : dAttr)
2459ba37b3bSJacques Pienaar       dOp.emitRemark() << d.first << " = " << d.second;
2469ba37b3bSJacques Pienaar   });
2479ba37b3bSJacques Pienaar }
2489ba37b3bSJacques Pienaar 
249fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
250fec6c5acSUday Bondhugula // Legalization Driver.
251fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
252fec6c5acSUday Bondhugula 
253fec6c5acSUday Bondhugula namespace {
254fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
255fec6c5acSUday Bondhugula // Region-Block Rewrite Testing
256fec6c5acSUday Bondhugula 
257fec6c5acSUday Bondhugula /// This pattern is a simple pattern that inlines the first region of a given
258fec6c5acSUday Bondhugula /// operation into the parent region.
259fec6c5acSUday Bondhugula struct TestRegionRewriteBlockMovement : public ConversionPattern {
260fec6c5acSUday Bondhugula   TestRegionRewriteBlockMovement(MLIRContext *ctx)
261fec6c5acSUday Bondhugula       : ConversionPattern("test.region", 1, ctx) {}
262fec6c5acSUday Bondhugula 
263fec6c5acSUday Bondhugula   LogicalResult
264fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
265fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
266fec6c5acSUday Bondhugula     // Inline this region into the parent region.
267fec6c5acSUday Bondhugula     auto &parentRegion = *op->getParentRegion();
268b0750e2dSTres Popp     auto &opRegion = op->getRegion(0);
269fec6c5acSUday Bondhugula     if (op->getAttr("legalizer.should_clone"))
270b0750e2dSTres Popp       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
271fec6c5acSUday Bondhugula     else
272b0750e2dSTres Popp       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
273b0750e2dSTres Popp 
274b0750e2dSTres Popp     if (op->getAttr("legalizer.erase_old_blocks")) {
275b0750e2dSTres Popp       while (!opRegion.empty())
276b0750e2dSTres Popp         rewriter.eraseBlock(&opRegion.front());
277b0750e2dSTres Popp     }
278fec6c5acSUday Bondhugula 
279fec6c5acSUday Bondhugula     // Drop this operation.
280fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
281fec6c5acSUday Bondhugula     return success();
282fec6c5acSUday Bondhugula   }
283fec6c5acSUday Bondhugula };
284fec6c5acSUday Bondhugula /// This pattern is a simple pattern that generates a region containing an
285fec6c5acSUday Bondhugula /// illegal operation.
286fec6c5acSUday Bondhugula struct TestRegionRewriteUndo : public RewritePattern {
287fec6c5acSUday Bondhugula   TestRegionRewriteUndo(MLIRContext *ctx)
288fec6c5acSUday Bondhugula       : RewritePattern("test.region_builder", 1, ctx) {}
289fec6c5acSUday Bondhugula 
290fec6c5acSUday Bondhugula   LogicalResult matchAndRewrite(Operation *op,
291fec6c5acSUday Bondhugula                                 PatternRewriter &rewriter) const final {
292fec6c5acSUday Bondhugula     // Create the region operation with an entry block containing arguments.
293fec6c5acSUday Bondhugula     OperationState newRegion(op->getLoc(), "test.region");
294fec6c5acSUday Bondhugula     newRegion.addRegion();
295fec6c5acSUday Bondhugula     auto *regionOp = rewriter.createOperation(newRegion);
296fec6c5acSUday Bondhugula     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
297fec6c5acSUday Bondhugula     entryBlock->addArgument(rewriter.getIntegerType(64));
298fec6c5acSUday Bondhugula 
299fec6c5acSUday Bondhugula     // Add an explicitly illegal operation to ensure the conversion fails.
300fec6c5acSUday Bondhugula     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
301fec6c5acSUday Bondhugula     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
302fec6c5acSUday Bondhugula 
303fec6c5acSUday Bondhugula     // Drop this operation.
304fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
305fec6c5acSUday Bondhugula     return success();
306fec6c5acSUday Bondhugula   }
307fec6c5acSUday Bondhugula };
308f27f1e8cSAlex Zinenko /// A simple pattern that creates a block at the end of the parent region of the
309f27f1e8cSAlex Zinenko /// matched operation.
310f27f1e8cSAlex Zinenko struct TestCreateBlock : public RewritePattern {
311f27f1e8cSAlex Zinenko   TestCreateBlock(MLIRContext *ctx)
312f27f1e8cSAlex Zinenko       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
313f27f1e8cSAlex Zinenko 
314f27f1e8cSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
315f27f1e8cSAlex Zinenko                                 PatternRewriter &rewriter) const final {
316f27f1e8cSAlex Zinenko     Region &region = *op->getParentRegion();
317f27f1e8cSAlex Zinenko     Type i32Type = rewriter.getIntegerType(32);
318f27f1e8cSAlex Zinenko     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
319f27f1e8cSAlex Zinenko     rewriter.create<TerminatorOp>(op->getLoc());
320f27f1e8cSAlex Zinenko     rewriter.replaceOp(op, {});
321f27f1e8cSAlex Zinenko     return success();
322f27f1e8cSAlex Zinenko   }
323f27f1e8cSAlex Zinenko };
324f27f1e8cSAlex Zinenko 
325a23d0559SKazuaki Ishizaki /// A simple pattern that creates a block containing an invalid operation in
326f27f1e8cSAlex Zinenko /// order to trigger the block creation undo mechanism.
327f27f1e8cSAlex Zinenko struct TestCreateIllegalBlock : public RewritePattern {
328f27f1e8cSAlex Zinenko   TestCreateIllegalBlock(MLIRContext *ctx)
329f27f1e8cSAlex Zinenko       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
330f27f1e8cSAlex Zinenko 
331f27f1e8cSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
332f27f1e8cSAlex Zinenko                                 PatternRewriter &rewriter) const final {
333f27f1e8cSAlex Zinenko     Region &region = *op->getParentRegion();
334f27f1e8cSAlex Zinenko     Type i32Type = rewriter.getIntegerType(32);
335f27f1e8cSAlex Zinenko     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
336f27f1e8cSAlex Zinenko     // Create an illegal op to ensure the conversion fails.
337f27f1e8cSAlex Zinenko     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
338f27f1e8cSAlex Zinenko     rewriter.create<TerminatorOp>(op->getLoc());
339f27f1e8cSAlex Zinenko     rewriter.replaceOp(op, {});
340f27f1e8cSAlex Zinenko     return success();
341f27f1e8cSAlex Zinenko   }
342f27f1e8cSAlex Zinenko };
343fec6c5acSUday Bondhugula 
3440816de16SRiver Riddle /// A simple pattern that tests the undo mechanism when replacing the uses of a
3450816de16SRiver Riddle /// block argument.
3460816de16SRiver Riddle struct TestUndoBlockArgReplace : public ConversionPattern {
3470816de16SRiver Riddle   TestUndoBlockArgReplace(MLIRContext *ctx)
3480816de16SRiver Riddle       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
3490816de16SRiver Riddle 
3500816de16SRiver Riddle   LogicalResult
3510816de16SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
3520816de16SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
3530816de16SRiver Riddle     auto illegalOp =
3540816de16SRiver Riddle         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
355e2b71610SRahul Joshi     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
3560816de16SRiver Riddle                                         illegalOp);
3570816de16SRiver Riddle     rewriter.updateRootInPlace(op, [] {});
3580816de16SRiver Riddle     return success();
3590816de16SRiver Riddle   }
3600816de16SRiver Riddle };
3610816de16SRiver Riddle 
362df48026bSAlex Zinenko /// A rewrite pattern that tests the undo mechanism when erasing a block.
363df48026bSAlex Zinenko struct TestUndoBlockErase : public ConversionPattern {
364df48026bSAlex Zinenko   TestUndoBlockErase(MLIRContext *ctx)
365df48026bSAlex Zinenko       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
366df48026bSAlex Zinenko 
367df48026bSAlex Zinenko   LogicalResult
368df48026bSAlex Zinenko   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
369df48026bSAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
370df48026bSAlex Zinenko     Block *secondBlock = &*std::next(op->getRegion(0).begin());
371df48026bSAlex Zinenko     rewriter.setInsertionPointToStart(secondBlock);
372df48026bSAlex Zinenko     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
373df48026bSAlex Zinenko     rewriter.eraseBlock(secondBlock);
374df48026bSAlex Zinenko     rewriter.updateRootInPlace(op, [] {});
375df48026bSAlex Zinenko     return success();
376df48026bSAlex Zinenko   }
377df48026bSAlex Zinenko };
378df48026bSAlex Zinenko 
379fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
380fec6c5acSUday Bondhugula // Type-Conversion Rewrite Testing
381fec6c5acSUday Bondhugula 
382fec6c5acSUday Bondhugula /// This patterns erases a region operation that has had a type conversion.
383fec6c5acSUday Bondhugula struct TestDropOpSignatureConversion : public ConversionPattern {
384fec6c5acSUday Bondhugula   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
38576f3c2f3SRiver Riddle       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
386fec6c5acSUday Bondhugula   LogicalResult
387fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
388fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const override {
389fec6c5acSUday Bondhugula     Region &region = op->getRegion(0);
390fec6c5acSUday Bondhugula     Block *entry = &region.front();
391fec6c5acSUday Bondhugula 
392fec6c5acSUday Bondhugula     // Convert the original entry arguments.
3938d67d187SRiver Riddle     TypeConverter &converter = *getTypeConverter();
394fec6c5acSUday Bondhugula     TypeConverter::SignatureConversion result(entry->getNumArguments());
3958d67d187SRiver Riddle     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
3968d67d187SRiver Riddle                                               result)) ||
3978d67d187SRiver Riddle         failed(rewriter.convertRegionTypes(&region, converter, &result)))
398fec6c5acSUday Bondhugula       return failure();
399fec6c5acSUday Bondhugula 
400fec6c5acSUday Bondhugula     // Convert the region signature and just drop the operation.
401fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
402fec6c5acSUday Bondhugula     return success();
403fec6c5acSUday Bondhugula   }
404fec6c5acSUday Bondhugula };
405fec6c5acSUday Bondhugula /// This pattern simply updates the operands of the given operation.
406fec6c5acSUday Bondhugula struct TestPassthroughInvalidOp : public ConversionPattern {
407fec6c5acSUday Bondhugula   TestPassthroughInvalidOp(MLIRContext *ctx)
408fec6c5acSUday Bondhugula       : ConversionPattern("test.invalid", 1, ctx) {}
409fec6c5acSUday Bondhugula   LogicalResult
410fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
411fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
412fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
413fec6c5acSUday Bondhugula                                              llvm::None);
414fec6c5acSUday Bondhugula     return success();
415fec6c5acSUday Bondhugula   }
416fec6c5acSUday Bondhugula };
417fec6c5acSUday Bondhugula /// This pattern handles the case of a split return value.
418fec6c5acSUday Bondhugula struct TestSplitReturnType : public ConversionPattern {
419fec6c5acSUday Bondhugula   TestSplitReturnType(MLIRContext *ctx)
420fec6c5acSUday Bondhugula       : ConversionPattern("test.return", 1, ctx) {}
421fec6c5acSUday Bondhugula   LogicalResult
422fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
423fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
424fec6c5acSUday Bondhugula     // Check for a return of F32.
425fec6c5acSUday Bondhugula     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
426fec6c5acSUday Bondhugula       return failure();
427fec6c5acSUday Bondhugula 
428fec6c5acSUday Bondhugula     // Check if the first operation is a cast operation, if it is we use the
429fec6c5acSUday Bondhugula     // results directly.
430fec6c5acSUday Bondhugula     auto *defOp = operands[0].getDefiningOp();
431fec6c5acSUday Bondhugula     if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
432fec6c5acSUday Bondhugula       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
433fec6c5acSUday Bondhugula       return success();
434fec6c5acSUday Bondhugula     }
435fec6c5acSUday Bondhugula 
436fec6c5acSUday Bondhugula     // Otherwise, fail to match.
437fec6c5acSUday Bondhugula     return failure();
438fec6c5acSUday Bondhugula   }
439fec6c5acSUday Bondhugula };
440fec6c5acSUday Bondhugula 
441fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
442fec6c5acSUday Bondhugula // Multi-Level Type-Conversion Rewrite Testing
443fec6c5acSUday Bondhugula struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
444fec6c5acSUday Bondhugula   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
445fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 1, ctx) {}
446fec6c5acSUday Bondhugula   LogicalResult
447fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
448fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
449fec6c5acSUday Bondhugula     // If the type is I32, change the type to F32.
450fec6c5acSUday Bondhugula     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
451fec6c5acSUday Bondhugula       return failure();
452fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
453fec6c5acSUday Bondhugula     return success();
454fec6c5acSUday Bondhugula   }
455fec6c5acSUday Bondhugula };
456fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
457fec6c5acSUday Bondhugula   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
458fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 1, ctx) {}
459fec6c5acSUday Bondhugula   LogicalResult
460fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
461fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
462fec6c5acSUday Bondhugula     // If the type is F32, change the type to F64.
463fec6c5acSUday Bondhugula     if (!Type(*op->result_type_begin()).isF32())
464fec6c5acSUday Bondhugula       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
465fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
466fec6c5acSUday Bondhugula     return success();
467fec6c5acSUday Bondhugula   }
468fec6c5acSUday Bondhugula };
469fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
470fec6c5acSUday Bondhugula   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
471fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 10, ctx) {}
472fec6c5acSUday Bondhugula   LogicalResult
473fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
474fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
475fec6c5acSUday Bondhugula     // Always convert to B16, even though it is not a legal type. This tests
476fec6c5acSUday Bondhugula     // that values are unmapped correctly.
477fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
478fec6c5acSUday Bondhugula     return success();
479fec6c5acSUday Bondhugula   }
480fec6c5acSUday Bondhugula };
481fec6c5acSUday Bondhugula struct TestUpdateConsumerType : public ConversionPattern {
482fec6c5acSUday Bondhugula   TestUpdateConsumerType(MLIRContext *ctx)
483fec6c5acSUday Bondhugula       : ConversionPattern("test.type_consumer", 1, ctx) {}
484fec6c5acSUday Bondhugula   LogicalResult
485fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
486fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
487fec6c5acSUday Bondhugula     // Verify that the incoming operand has been successfully remapped to F64.
488fec6c5acSUday Bondhugula     if (!operands[0].getType().isF64())
489fec6c5acSUday Bondhugula       return failure();
490fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
491fec6c5acSUday Bondhugula     return success();
492fec6c5acSUday Bondhugula   }
493fec6c5acSUday Bondhugula };
494fec6c5acSUday Bondhugula 
495fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
496fec6c5acSUday Bondhugula // Non-Root Replacement Rewrite Testing
497fec6c5acSUday Bondhugula /// This pattern generates an invalid operation, but replaces it before the
498fec6c5acSUday Bondhugula /// pattern is finished. This checks that we don't need to legalize the
499fec6c5acSUday Bondhugula /// temporary op.
500fec6c5acSUday Bondhugula struct TestNonRootReplacement : public RewritePattern {
501fec6c5acSUday Bondhugula   TestNonRootReplacement(MLIRContext *ctx)
502fec6c5acSUday Bondhugula       : RewritePattern("test.replace_non_root", 1, ctx) {}
503fec6c5acSUday Bondhugula 
504fec6c5acSUday Bondhugula   LogicalResult matchAndRewrite(Operation *op,
505fec6c5acSUday Bondhugula                                 PatternRewriter &rewriter) const final {
506fec6c5acSUday Bondhugula     auto resultType = *op->result_type_begin();
507fec6c5acSUday Bondhugula     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
508fec6c5acSUday Bondhugula     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
509fec6c5acSUday Bondhugula 
510fec6c5acSUday Bondhugula     rewriter.replaceOp(illegalOp, {legalOp});
511fec6c5acSUday Bondhugula     rewriter.replaceOp(op, {illegalOp});
512fec6c5acSUday Bondhugula     return success();
513fec6c5acSUday Bondhugula   }
514fec6c5acSUday Bondhugula };
515bd1ccfe6SRiver Riddle 
516bd1ccfe6SRiver Riddle //===----------------------------------------------------------------------===//
517bd1ccfe6SRiver Riddle // Recursive Rewrite Testing
518bd1ccfe6SRiver Riddle /// This pattern is applied to the same operation multiple times, but has a
519bd1ccfe6SRiver Riddle /// bounded recursion.
520bd1ccfe6SRiver Riddle struct TestBoundedRecursiveRewrite
521bd1ccfe6SRiver Riddle     : public OpRewritePattern<TestRecursiveRewriteOp> {
5222257e4a7SRiver Riddle   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
5232257e4a7SRiver Riddle 
5242257e4a7SRiver Riddle   void initialize() {
525b99bd771SRiver Riddle     // The conversion target handles bounding the recursion of this pattern.
526b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
527b99bd771SRiver Riddle   }
528bd1ccfe6SRiver Riddle 
529bd1ccfe6SRiver Riddle   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
530bd1ccfe6SRiver Riddle                                 PatternRewriter &rewriter) const final {
531bd1ccfe6SRiver Riddle     // Decrement the depth of the op in-place.
532bd1ccfe6SRiver Riddle     rewriter.updateRootInPlace(op, [&] {
5331ffc1aaaSChristian Sigg       op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1));
534bd1ccfe6SRiver Riddle     });
535bd1ccfe6SRiver Riddle     return success();
536bd1ccfe6SRiver Riddle   }
537bd1ccfe6SRiver Riddle };
5385d5df06aSAlex Zinenko 
5395d5df06aSAlex Zinenko struct TestNestedOpCreationUndoRewrite
5405d5df06aSAlex Zinenko     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
5415d5df06aSAlex Zinenko   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
5425d5df06aSAlex Zinenko 
5435d5df06aSAlex Zinenko   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
5445d5df06aSAlex Zinenko                                 PatternRewriter &rewriter) const final {
5455d5df06aSAlex Zinenko     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
5465d5df06aSAlex Zinenko     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
5475d5df06aSAlex Zinenko     return success();
5485d5df06aSAlex Zinenko   };
5495d5df06aSAlex Zinenko };
550a360a978SMehdi Amini 
551a360a978SMehdi Amini // This pattern matches `test.blackhole` and delete this op and its producer.
552a360a978SMehdi Amini struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
553a360a978SMehdi Amini   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
554a360a978SMehdi Amini 
555a360a978SMehdi Amini   LogicalResult matchAndRewrite(BlackHoleOp op,
556a360a978SMehdi Amini                                 PatternRewriter &rewriter) const final {
557a360a978SMehdi Amini     Operation *producer = op.getOperand().getDefiningOp();
558a360a978SMehdi Amini     // Always erase the user before the producer, the framework should handle
559a360a978SMehdi Amini     // this correctly.
560a360a978SMehdi Amini     rewriter.eraseOp(op);
561a360a978SMehdi Amini     rewriter.eraseOp(producer);
562a360a978SMehdi Amini     return success();
563a360a978SMehdi Amini   };
564a360a978SMehdi Amini };
565fec6c5acSUday Bondhugula } // namespace
566fec6c5acSUday Bondhugula 
567fec6c5acSUday Bondhugula namespace {
568fec6c5acSUday Bondhugula struct TestTypeConverter : public TypeConverter {
569fec6c5acSUday Bondhugula   using TypeConverter::TypeConverter;
5705c5dafc5SAlex Zinenko   TestTypeConverter() {
5715c5dafc5SAlex Zinenko     addConversion(convertType);
5724589dd92SRiver Riddle     addArgumentMaterialization(materializeCast);
5734589dd92SRiver Riddle     addSourceMaterialization(materializeCast);
5747cc79984SVladislav Vinogradov 
5757cc79984SVladislav Vinogradov     /// Materialize the cast for one-to-one conversion from i64 to f64.
5767cc79984SVladislav Vinogradov     const auto materializeOneToOneCast =
5777cc79984SVladislav Vinogradov         [](OpBuilder &builder, IntegerType resultType, ValueRange inputs,
5787cc79984SVladislav Vinogradov            Location loc) -> Optional<Value> {
5797cc79984SVladislav Vinogradov       if (resultType.getWidth() == 42 && inputs.size() == 1)
5807cc79984SVladislav Vinogradov         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
5817cc79984SVladislav Vinogradov       return llvm::None;
5827cc79984SVladislav Vinogradov     };
5837cc79984SVladislav Vinogradov     addArgumentMaterialization(materializeOneToOneCast);
5845c5dafc5SAlex Zinenko   }
585fec6c5acSUday Bondhugula 
586fec6c5acSUday Bondhugula   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
587fec6c5acSUday Bondhugula     // Drop I16 types.
588fec6c5acSUday Bondhugula     if (t.isSignlessInteger(16))
589fec6c5acSUday Bondhugula       return success();
590fec6c5acSUday Bondhugula 
591fec6c5acSUday Bondhugula     // Convert I64 to F64.
592fec6c5acSUday Bondhugula     if (t.isSignlessInteger(64)) {
593fec6c5acSUday Bondhugula       results.push_back(FloatType::getF64(t.getContext()));
594fec6c5acSUday Bondhugula       return success();
595fec6c5acSUday Bondhugula     }
596fec6c5acSUday Bondhugula 
5975c5dafc5SAlex Zinenko     // Convert I42 to I43.
5985c5dafc5SAlex Zinenko     if (t.isInteger(42)) {
5991b97cdf8SRiver Riddle       results.push_back(IntegerType::get(t.getContext(), 43));
6005c5dafc5SAlex Zinenko       return success();
6015c5dafc5SAlex Zinenko     }
6025c5dafc5SAlex Zinenko 
603fec6c5acSUday Bondhugula     // Split F32 into F16,F16.
604fec6c5acSUday Bondhugula     if (t.isF32()) {
605fec6c5acSUday Bondhugula       results.assign(2, FloatType::getF16(t.getContext()));
606fec6c5acSUday Bondhugula       return success();
607fec6c5acSUday Bondhugula     }
608fec6c5acSUday Bondhugula 
609fec6c5acSUday Bondhugula     // Otherwise, convert the type directly.
610fec6c5acSUday Bondhugula     results.push_back(t);
611fec6c5acSUday Bondhugula     return success();
612fec6c5acSUday Bondhugula   }
613fec6c5acSUday Bondhugula 
6145c5dafc5SAlex Zinenko   /// Hook for materializing a conversion. This is necessary because we generate
6155c5dafc5SAlex Zinenko   /// 1->N type mappings.
6164589dd92SRiver Riddle   static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
6174589dd92SRiver Riddle                                          ValueRange inputs, Location loc) {
6185c5dafc5SAlex Zinenko     if (inputs.size() == 1)
6195c5dafc5SAlex Zinenko       return inputs[0];
6204589dd92SRiver Riddle     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
6215c5dafc5SAlex Zinenko   }
622fec6c5acSUday Bondhugula };
623fec6c5acSUday Bondhugula 
624fec6c5acSUday Bondhugula struct TestLegalizePatternDriver
62580aca1eaSRiver Riddle     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
626b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-legalize-patterns"; }
627b5e22e6dSMehdi Amini   StringRef getDescription() const final {
628b5e22e6dSMehdi Amini     return "Run test dialect legalization patterns";
629b5e22e6dSMehdi Amini   }
630fec6c5acSUday Bondhugula   /// The mode of conversion to use with the driver.
631fec6c5acSUday Bondhugula   enum class ConversionMode { Analysis, Full, Partial };
632fec6c5acSUday Bondhugula 
633fec6c5acSUday Bondhugula   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
634fec6c5acSUday Bondhugula 
635722f909fSRiver Riddle   void runOnOperation() override {
636fec6c5acSUday Bondhugula     TestTypeConverter converter;
637dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
6381d909c9aSChris Lattner     populateWithGenerated(patterns);
639dc4e913bSChris Lattner     patterns
640dc4e913bSChris Lattner         .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
641dc4e913bSChris Lattner              TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
642dc4e913bSChris Lattner              TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
643df48026bSAlex Zinenko              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
644fec6c5acSUday Bondhugula              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
6455d5df06aSAlex Zinenko              TestNonRootReplacement, TestBoundedRecursiveRewrite,
646a360a978SMehdi Amini              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>(
647a360a978SMehdi Amini             &getContext());
648dc4e913bSChris Lattner     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
6493a506b31SChris Lattner     mlir::populateFuncOpTypeConversionPattern(patterns, converter);
6503a506b31SChris Lattner     mlir::populateCallOpTypeConversionPattern(patterns, converter);
651fec6c5acSUday Bondhugula 
652fec6c5acSUday Bondhugula     // Define the conversion target used for the test.
653fec6c5acSUday Bondhugula     ConversionTarget target(getContext());
654973ddb7dSMehdi Amini     target.addLegalOp<ModuleOp>();
655f27f1e8cSAlex Zinenko     target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
656f27f1e8cSAlex Zinenko                       TerminatorOp>();
657fec6c5acSUday Bondhugula     target
658fec6c5acSUday Bondhugula         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
659fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
660fec6c5acSUday Bondhugula       // Don't allow F32 operands.
661fec6c5acSUday Bondhugula       return llvm::none_of(op.getOperandTypes(),
662fec6c5acSUday Bondhugula                            [](Type type) { return type.isF32(); });
663fec6c5acSUday Bondhugula     });
6648d67d187SRiver Riddle     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
6658d67d187SRiver Riddle       return converter.isSignatureLegal(op.getType()) &&
6668d67d187SRiver Riddle              converter.isLegal(&op.getBody());
6678d67d187SRiver Riddle     });
668fec6c5acSUday Bondhugula 
669fec6c5acSUday Bondhugula     // Expect the type_producer/type_consumer operations to only operate on f64.
670fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestTypeProducerOp>(
671fec6c5acSUday Bondhugula         [](TestTypeProducerOp op) { return op.getType().isF64(); });
672fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
673fec6c5acSUday Bondhugula       return op.getOperand().getType().isF64();
674fec6c5acSUday Bondhugula     });
675fec6c5acSUday Bondhugula 
676fec6c5acSUday Bondhugula     // Check support for marking certain operations as recursively legal.
677fec6c5acSUday Bondhugula     target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
678fec6c5acSUday Bondhugula       return static_cast<bool>(
679fec6c5acSUday Bondhugula           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
680fec6c5acSUday Bondhugula     });
681fec6c5acSUday Bondhugula 
682bd1ccfe6SRiver Riddle     // Mark the bound recursion operation as dynamically legal.
683bd1ccfe6SRiver Riddle     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
684bd1ccfe6SRiver Riddle         [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
685bd1ccfe6SRiver Riddle 
686fec6c5acSUday Bondhugula     // Handle a partial conversion.
687fec6c5acSUday Bondhugula     if (mode == ConversionMode::Partial) {
6888de482eaSLucy Fox       DenseSet<Operation *> unlegalizedOps;
6893fffffa8SRiver Riddle       (void)applyPartialConversion(getOperation(), target, std::move(patterns),
6908de482eaSLucy Fox                                    &unlegalizedOps);
6918de482eaSLucy Fox       // Emit remarks for each legalizable operation.
6928de482eaSLucy Fox       for (auto *op : unlegalizedOps)
6938de482eaSLucy Fox         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
694fec6c5acSUday Bondhugula       return;
695fec6c5acSUday Bondhugula     }
696fec6c5acSUday Bondhugula 
697fec6c5acSUday Bondhugula     // Handle a full conversion.
698fec6c5acSUday Bondhugula     if (mode == ConversionMode::Full) {
699fec6c5acSUday Bondhugula       // Check support for marking unknown operations as dynamically legal.
700fec6c5acSUday Bondhugula       target.markUnknownOpDynamicallyLegal([](Operation *op) {
701fec6c5acSUday Bondhugula         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
702fec6c5acSUday Bondhugula       });
703fec6c5acSUday Bondhugula 
7043fffffa8SRiver Riddle       (void)applyFullConversion(getOperation(), target, std::move(patterns));
705fec6c5acSUday Bondhugula       return;
706fec6c5acSUday Bondhugula     }
707fec6c5acSUday Bondhugula 
708fec6c5acSUday Bondhugula     // Otherwise, handle an analysis conversion.
709fec6c5acSUday Bondhugula     assert(mode == ConversionMode::Analysis);
710fec6c5acSUday Bondhugula 
711fec6c5acSUday Bondhugula     // Analyze the convertible operations.
712fec6c5acSUday Bondhugula     DenseSet<Operation *> legalizedOps;
7133fffffa8SRiver Riddle     if (failed(applyAnalysisConversion(getOperation(), target,
7143fffffa8SRiver Riddle                                        std::move(patterns), legalizedOps)))
715fec6c5acSUday Bondhugula       return signalPassFailure();
716fec6c5acSUday Bondhugula 
717fec6c5acSUday Bondhugula     // Emit remarks for each legalizable operation.
718fec6c5acSUday Bondhugula     for (auto *op : legalizedOps)
719fec6c5acSUday Bondhugula       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
720fec6c5acSUday Bondhugula   }
721fec6c5acSUday Bondhugula 
722fec6c5acSUday Bondhugula   /// The mode of conversion to use.
723fec6c5acSUday Bondhugula   ConversionMode mode;
724fec6c5acSUday Bondhugula };
725fec6c5acSUday Bondhugula } // end anonymous namespace
726fec6c5acSUday Bondhugula 
727fec6c5acSUday Bondhugula static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
728fec6c5acSUday Bondhugula     legalizerConversionMode(
729fec6c5acSUday Bondhugula         "test-legalize-mode",
730fec6c5acSUday Bondhugula         llvm::cl::desc("The legalization mode to use with the test driver"),
731fec6c5acSUday Bondhugula         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
732fec6c5acSUday Bondhugula         llvm::cl::values(
733fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
734fec6c5acSUday Bondhugula                        "analysis", "Perform an analysis conversion"),
735fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
736fec6c5acSUday Bondhugula                        "Perform a full conversion"),
737fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
738fec6c5acSUday Bondhugula                        "partial", "Perform a partial conversion")));
739fec6c5acSUday Bondhugula 
740fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
741fec6c5acSUday Bondhugula // ConversionPatternRewriter::getRemappedValue testing. This method is used
7425aacce3dSKazuaki Ishizaki // to get the remapped value of an original value that was replaced using
743fec6c5acSUday Bondhugula // ConversionPatternRewriter.
744fec6c5acSUday Bondhugula namespace {
745fec6c5acSUday Bondhugula /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
746fec6c5acSUday Bondhugula /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
747fec6c5acSUday Bondhugula /// operand twice.
748fec6c5acSUday Bondhugula ///
749fec6c5acSUday Bondhugula /// Example:
750fec6c5acSUday Bondhugula ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
751fec6c5acSUday Bondhugula /// is replaced with:
752fec6c5acSUday Bondhugula ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
753fec6c5acSUday Bondhugula struct OneVResOneVOperandOp1Converter
754fec6c5acSUday Bondhugula     : public OpConversionPattern<OneVResOneVOperandOp1> {
755fec6c5acSUday Bondhugula   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
756fec6c5acSUday Bondhugula 
757fec6c5acSUday Bondhugula   LogicalResult
758fec6c5acSUday Bondhugula   matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
759fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const override {
760fec6c5acSUday Bondhugula     auto origOps = op.getOperands();
761fec6c5acSUday Bondhugula     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
762fec6c5acSUday Bondhugula            "One operand expected");
763fec6c5acSUday Bondhugula     Value origOp = *origOps.begin();
764fec6c5acSUday Bondhugula     SmallVector<Value, 2> remappedOperands;
765fec6c5acSUday Bondhugula     // Replicate the remapped original operand twice. Note that we don't used
766fec6c5acSUday Bondhugula     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
767fec6c5acSUday Bondhugula     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
768fec6c5acSUday Bondhugula     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
769fec6c5acSUday Bondhugula 
770fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
771fec6c5acSUday Bondhugula                                                        remappedOperands);
772fec6c5acSUday Bondhugula     return success();
773fec6c5acSUday Bondhugula   }
774fec6c5acSUday Bondhugula };
775fec6c5acSUday Bondhugula 
77680aca1eaSRiver Riddle struct TestRemappedValue
77780aca1eaSRiver Riddle     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
778b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-remapped-value"; }
779b5e22e6dSMehdi Amini   StringRef getDescription() const final {
780b5e22e6dSMehdi Amini     return "Test public remapped value mechanism in ConversionPatternRewriter";
781b5e22e6dSMehdi Amini   }
782fec6c5acSUday Bondhugula   void runOnFunction() override {
783dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
784dc4e913bSChris Lattner     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
785fec6c5acSUday Bondhugula 
786fec6c5acSUday Bondhugula     mlir::ConversionTarget target(getContext());
787973ddb7dSMehdi Amini     target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
788fec6c5acSUday Bondhugula     // We make OneVResOneVOperandOp1 legal only when it has more that one
789fec6c5acSUday Bondhugula     // operand. This will trigger the conversion that will replace one-operand
790fec6c5acSUday Bondhugula     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
791fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
792fec6c5acSUday Bondhugula         [](Operation *op) -> bool {
793fec6c5acSUday Bondhugula           return std::distance(op->operand_begin(), op->operand_end()) > 1;
794fec6c5acSUday Bondhugula         });
795fec6c5acSUday Bondhugula 
7963fffffa8SRiver Riddle     if (failed(mlir::applyFullConversion(getFunction(), target,
7973fffffa8SRiver Riddle                                          std::move(patterns)))) {
798fec6c5acSUday Bondhugula       signalPassFailure();
799fec6c5acSUday Bondhugula     }
800fec6c5acSUday Bondhugula   }
801fec6c5acSUday Bondhugula };
802fec6c5acSUday Bondhugula } // end anonymous namespace
803fec6c5acSUday Bondhugula 
80480d7ac3bSRiver Riddle //===----------------------------------------------------------------------===//
80580d7ac3bSRiver Riddle // Test patterns without a specific root operation kind
80680d7ac3bSRiver Riddle //===----------------------------------------------------------------------===//
80780d7ac3bSRiver Riddle 
80880d7ac3bSRiver Riddle namespace {
80980d7ac3bSRiver Riddle /// This pattern matches and removes any operation in the test dialect.
81080d7ac3bSRiver Riddle struct RemoveTestDialectOps : public RewritePattern {
81176f3c2f3SRiver Riddle   RemoveTestDialectOps(MLIRContext *context)
81276f3c2f3SRiver Riddle       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
81380d7ac3bSRiver Riddle 
81480d7ac3bSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
81580d7ac3bSRiver Riddle                                 PatternRewriter &rewriter) const override {
81680d7ac3bSRiver Riddle     if (!isa<TestDialect>(op->getDialect()))
81780d7ac3bSRiver Riddle       return failure();
81880d7ac3bSRiver Riddle     rewriter.eraseOp(op);
81980d7ac3bSRiver Riddle     return success();
82080d7ac3bSRiver Riddle   }
82180d7ac3bSRiver Riddle };
82280d7ac3bSRiver Riddle 
82380d7ac3bSRiver Riddle struct TestUnknownRootOpDriver
82480d7ac3bSRiver Riddle     : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
825b5e22e6dSMehdi Amini   StringRef getArgument() const final {
826b5e22e6dSMehdi Amini     return "test-legalize-unknown-root-patterns";
827b5e22e6dSMehdi Amini   }
828b5e22e6dSMehdi Amini   StringRef getDescription() const final {
829b5e22e6dSMehdi Amini     return "Test public remapped value mechanism in ConversionPatternRewriter";
830b5e22e6dSMehdi Amini   }
83180d7ac3bSRiver Riddle   void runOnFunction() override {
832dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
83376f3c2f3SRiver Riddle     patterns.add<RemoveTestDialectOps>(&getContext());
83480d7ac3bSRiver Riddle 
83580d7ac3bSRiver Riddle     mlir::ConversionTarget target(getContext());
83680d7ac3bSRiver Riddle     target.addIllegalDialect<TestDialect>();
8373fffffa8SRiver Riddle     if (failed(
8383fffffa8SRiver Riddle             applyPartialConversion(getFunction(), target, std::move(patterns))))
83980d7ac3bSRiver Riddle       signalPassFailure();
84080d7ac3bSRiver Riddle   }
84180d7ac3bSRiver Riddle };
84280d7ac3bSRiver Riddle } // end anonymous namespace
84380d7ac3bSRiver Riddle 
8444589dd92SRiver Riddle //===----------------------------------------------------------------------===//
8454589dd92SRiver Riddle // Test type conversions
8464589dd92SRiver Riddle //===----------------------------------------------------------------------===//
8474589dd92SRiver Riddle 
8484589dd92SRiver Riddle namespace {
8494589dd92SRiver Riddle struct TestTypeConversionProducer
8504589dd92SRiver Riddle     : public OpConversionPattern<TestTypeProducerOp> {
8514589dd92SRiver Riddle   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
8524589dd92SRiver Riddle   LogicalResult
8534589dd92SRiver Riddle   matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands,
8544589dd92SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
8554589dd92SRiver Riddle     Type resultType = op.getType();
8564589dd92SRiver Riddle     if (resultType.isa<FloatType>())
8574589dd92SRiver Riddle       resultType = rewriter.getF64Type();
8584589dd92SRiver Riddle     else if (resultType.isInteger(16))
8594589dd92SRiver Riddle       resultType = rewriter.getIntegerType(64);
8604589dd92SRiver Riddle     else
8614589dd92SRiver Riddle       return failure();
8624589dd92SRiver Riddle 
8634589dd92SRiver Riddle     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
8644589dd92SRiver Riddle     return success();
8654589dd92SRiver Riddle   }
8664589dd92SRiver Riddle };
8674589dd92SRiver Riddle 
8680409eb28SAlex Zinenko /// Call signature conversion and then fail the rewrite to trigger the undo
8690409eb28SAlex Zinenko /// mechanism.
8700409eb28SAlex Zinenko struct TestSignatureConversionUndo
8710409eb28SAlex Zinenko     : public OpConversionPattern<TestSignatureConversionUndoOp> {
8720409eb28SAlex Zinenko   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
8730409eb28SAlex Zinenko 
8740409eb28SAlex Zinenko   LogicalResult
8750409eb28SAlex Zinenko   matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands,
8760409eb28SAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
8770409eb28SAlex Zinenko     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
8780409eb28SAlex Zinenko     return failure();
8790409eb28SAlex Zinenko   }
8800409eb28SAlex Zinenko };
8810409eb28SAlex Zinenko 
8820409eb28SAlex Zinenko /// Just forward the operands to the root op. This is essentially a no-op
8830409eb28SAlex Zinenko /// pattern that is used to trigger target materialization.
8840409eb28SAlex Zinenko struct TestTypeConsumerForward
8850409eb28SAlex Zinenko     : public OpConversionPattern<TestTypeConsumerOp> {
8860409eb28SAlex Zinenko   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
8870409eb28SAlex Zinenko 
8880409eb28SAlex Zinenko   LogicalResult
8890409eb28SAlex Zinenko   matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands,
8900409eb28SAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
8910409eb28SAlex Zinenko     rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); });
8920409eb28SAlex Zinenko     return success();
8930409eb28SAlex Zinenko   }
8940409eb28SAlex Zinenko };
8950409eb28SAlex Zinenko 
8965b91060dSAlex Zinenko struct TestTypeConversionAnotherProducer
8975b91060dSAlex Zinenko     : public OpRewritePattern<TestAnotherTypeProducerOp> {
8985b91060dSAlex Zinenko   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
8995b91060dSAlex Zinenko 
9005b91060dSAlex Zinenko   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
9015b91060dSAlex Zinenko                                 PatternRewriter &rewriter) const final {
9025b91060dSAlex Zinenko     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
9035b91060dSAlex Zinenko     return success();
9045b91060dSAlex Zinenko   }
9055b91060dSAlex Zinenko };
9065b91060dSAlex Zinenko 
9074589dd92SRiver Riddle struct TestTypeConversionDriver
9084589dd92SRiver Riddle     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
909f9dc2b70SMehdi Amini   void getDependentDialects(DialectRegistry &registry) const override {
910f9dc2b70SMehdi Amini     registry.insert<TestDialect>();
911f9dc2b70SMehdi Amini   }
912b5e22e6dSMehdi Amini   StringRef getArgument() const final {
913b5e22e6dSMehdi Amini     return "test-legalize-type-conversion";
914b5e22e6dSMehdi Amini   }
915b5e22e6dSMehdi Amini   StringRef getDescription() const final {
916b5e22e6dSMehdi Amini     return "Test various type conversion functionalities in DialectConversion";
917b5e22e6dSMehdi Amini   }
918f9dc2b70SMehdi Amini 
9194589dd92SRiver Riddle   void runOnOperation() override {
9204589dd92SRiver Riddle     // Initialize the type converter.
9214589dd92SRiver Riddle     TypeConverter converter;
9224589dd92SRiver Riddle 
9234589dd92SRiver Riddle     /// Add the legal set of type conversions.
9244589dd92SRiver Riddle     converter.addConversion([](Type type) -> Type {
9254589dd92SRiver Riddle       // Treat F64 as legal.
9264589dd92SRiver Riddle       if (type.isF64())
9274589dd92SRiver Riddle         return type;
9284589dd92SRiver Riddle       // Allow converting BF16/F16/F32 to F64.
9294589dd92SRiver Riddle       if (type.isBF16() || type.isF16() || type.isF32())
9304589dd92SRiver Riddle         return FloatType::getF64(type.getContext());
9314589dd92SRiver Riddle       // Otherwise, the type is illegal.
9324589dd92SRiver Riddle       return nullptr;
9334589dd92SRiver Riddle     });
9344589dd92SRiver Riddle     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
9354589dd92SRiver Riddle       // Drop all integer types.
9364589dd92SRiver Riddle       return success();
9374589dd92SRiver Riddle     });
9384589dd92SRiver Riddle 
9394589dd92SRiver Riddle     /// Add the legal set of type materializations.
9404589dd92SRiver Riddle     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
9414589dd92SRiver Riddle                                           ValueRange inputs,
9424589dd92SRiver Riddle                                           Location loc) -> Value {
9434589dd92SRiver Riddle       // Allow casting from F64 back to F32.
9444589dd92SRiver Riddle       if (!resultType.isF16() && inputs.size() == 1 &&
9454589dd92SRiver Riddle           inputs[0].getType().isF64())
9464589dd92SRiver Riddle         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
9474589dd92SRiver Riddle       // Allow producing an i32 or i64 from nothing.
9484589dd92SRiver Riddle       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
9494589dd92SRiver Riddle           inputs.empty())
9504589dd92SRiver Riddle         return builder.create<TestTypeProducerOp>(loc, resultType);
9514589dd92SRiver Riddle       // Allow producing an i64 from an integer.
9524589dd92SRiver Riddle       if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
9534589dd92SRiver Riddle           inputs[0].getType().isa<IntegerType>())
9544589dd92SRiver Riddle         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
9554589dd92SRiver Riddle       // Otherwise, fail.
9564589dd92SRiver Riddle       return nullptr;
9574589dd92SRiver Riddle     });
9584589dd92SRiver Riddle 
9594589dd92SRiver Riddle     // Initialize the conversion target.
9604589dd92SRiver Riddle     mlir::ConversionTarget target(getContext());
9614589dd92SRiver Riddle     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
9624589dd92SRiver Riddle       return op.getType().isF64() || op.getType().isInteger(64);
9634589dd92SRiver Riddle     });
9644589dd92SRiver Riddle     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
9654589dd92SRiver Riddle       return converter.isSignatureLegal(op.getType()) &&
9664589dd92SRiver Riddle              converter.isLegal(&op.getBody());
9674589dd92SRiver Riddle     });
9684589dd92SRiver Riddle     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
9694589dd92SRiver Riddle       // Allow casts from F64 to F32.
9704589dd92SRiver Riddle       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
9714589dd92SRiver Riddle     });
9724589dd92SRiver Riddle 
9734589dd92SRiver Riddle     // Initialize the set of rewrite patterns.
974dc4e913bSChris Lattner     RewritePatternSet patterns(&getContext());
975dc4e913bSChris Lattner     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
9760409eb28SAlex Zinenko                  TestSignatureConversionUndo>(converter, &getContext());
977dc4e913bSChris Lattner     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
9783a506b31SChris Lattner     mlir::populateFuncOpTypeConversionPattern(patterns, converter);
9794589dd92SRiver Riddle 
9803fffffa8SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
9813fffffa8SRiver Riddle                                       std::move(patterns))))
9824589dd92SRiver Riddle       signalPassFailure();
9834589dd92SRiver Riddle   }
9844589dd92SRiver Riddle };
9854589dd92SRiver Riddle } // end anonymous namespace
9864589dd92SRiver Riddle 
987c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
988c8fb6ee3SRiver Riddle // Test Block Merging
989c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
990c8fb6ee3SRiver Riddle 
991e888886cSMaheshRavishankar namespace {
992e888886cSMaheshRavishankar /// A rewriter pattern that tests that blocks can be merged.
993e888886cSMaheshRavishankar struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
994e888886cSMaheshRavishankar   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
995e888886cSMaheshRavishankar 
996e888886cSMaheshRavishankar   LogicalResult
997e888886cSMaheshRavishankar   matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands,
998e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
999e888886cSMaheshRavishankar     Block &firstBlock = op.body().front();
1000e888886cSMaheshRavishankar     Operation *branchOp = firstBlock.getTerminator();
1001e888886cSMaheshRavishankar     Block *secondBlock = &*(std::next(op.body().begin()));
1002e888886cSMaheshRavishankar     auto succOperands = branchOp->getOperands();
1003e888886cSMaheshRavishankar     SmallVector<Value, 2> replacements(succOperands);
1004e888886cSMaheshRavishankar     rewriter.eraseOp(branchOp);
1005e888886cSMaheshRavishankar     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1006e888886cSMaheshRavishankar     rewriter.updateRootInPlace(op, [] {});
1007e888886cSMaheshRavishankar     return success();
1008e888886cSMaheshRavishankar   }
1009e888886cSMaheshRavishankar };
1010e888886cSMaheshRavishankar 
1011e888886cSMaheshRavishankar /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1012e888886cSMaheshRavishankar struct TestUndoBlocksMerge : public ConversionPattern {
1013e888886cSMaheshRavishankar   TestUndoBlocksMerge(MLIRContext *ctx)
1014e888886cSMaheshRavishankar       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1015e888886cSMaheshRavishankar   LogicalResult
1016e888886cSMaheshRavishankar   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1017e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
1018e888886cSMaheshRavishankar     Block &firstBlock = op->getRegion(0).front();
1019e888886cSMaheshRavishankar     Operation *branchOp = firstBlock.getTerminator();
1020e888886cSMaheshRavishankar     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1021e888886cSMaheshRavishankar     rewriter.setInsertionPointToStart(secondBlock);
1022e888886cSMaheshRavishankar     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1023e888886cSMaheshRavishankar     auto succOperands = branchOp->getOperands();
1024e888886cSMaheshRavishankar     SmallVector<Value, 2> replacements(succOperands);
1025e888886cSMaheshRavishankar     rewriter.eraseOp(branchOp);
1026e888886cSMaheshRavishankar     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1027e888886cSMaheshRavishankar     rewriter.updateRootInPlace(op, [] {});
1028e888886cSMaheshRavishankar     return success();
1029e888886cSMaheshRavishankar   }
1030e888886cSMaheshRavishankar };
1031e888886cSMaheshRavishankar 
1032e888886cSMaheshRavishankar /// A rewrite mechanism to inline the body of the op into its parent, when both
1033e888886cSMaheshRavishankar /// ops can have a single block.
1034e888886cSMaheshRavishankar struct TestMergeSingleBlockOps
1035e888886cSMaheshRavishankar     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1036e888886cSMaheshRavishankar   using OpConversionPattern<
1037e888886cSMaheshRavishankar       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1038e888886cSMaheshRavishankar 
1039e888886cSMaheshRavishankar   LogicalResult
1040e888886cSMaheshRavishankar   matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands,
1041e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
1042e888886cSMaheshRavishankar     SingleBlockImplicitTerminatorOp parentOp =
10430bf4a82aSChristian Sigg         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1044e888886cSMaheshRavishankar     if (!parentOp)
1045e888886cSMaheshRavishankar       return failure();
1046e888886cSMaheshRavishankar     Block &innerBlock = op.region().front();
1047e888886cSMaheshRavishankar     TerminatorOp innerTerminator =
1048e888886cSMaheshRavishankar         cast<TerminatorOp>(innerBlock.getTerminator());
10499c7b0c4aSRahul Joshi     rewriter.mergeBlockBefore(&innerBlock, op);
1050e888886cSMaheshRavishankar     rewriter.eraseOp(innerTerminator);
1051e888886cSMaheshRavishankar     rewriter.eraseOp(op);
1052e888886cSMaheshRavishankar     rewriter.updateRootInPlace(op, [] {});
1053e888886cSMaheshRavishankar     return success();
1054e888886cSMaheshRavishankar   }
1055e888886cSMaheshRavishankar };
1056e888886cSMaheshRavishankar 
1057e888886cSMaheshRavishankar struct TestMergeBlocksPatternDriver
1058e888886cSMaheshRavishankar     : public PassWrapper<TestMergeBlocksPatternDriver,
1059e888886cSMaheshRavishankar                          OperationPass<ModuleOp>> {
1060b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-merge-blocks"; }
1061b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1062b5e22e6dSMehdi Amini     return "Test Merging operation in ConversionPatternRewriter";
1063b5e22e6dSMehdi Amini   }
1064e888886cSMaheshRavishankar   void runOnOperation() override {
1065e888886cSMaheshRavishankar     MLIRContext *context = &getContext();
1066dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(context);
1067dc4e913bSChris Lattner     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1068e888886cSMaheshRavishankar         context);
1069e888886cSMaheshRavishankar     ConversionTarget target(*context);
1070973ddb7dSMehdi Amini     target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1071973ddb7dSMehdi Amini                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1072e888886cSMaheshRavishankar     target.addIllegalOp<ILLegalOpF>();
1073e888886cSMaheshRavishankar 
1074e888886cSMaheshRavishankar     /// Expect the op to have a single block after legalization.
1075e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1076e888886cSMaheshRavishankar         [&](TestMergeBlocksOp op) -> bool {
1077e888886cSMaheshRavishankar           return llvm::hasSingleElement(op.body());
1078e888886cSMaheshRavishankar         });
1079e888886cSMaheshRavishankar 
1080e888886cSMaheshRavishankar     /// Only allow `test.br` within test.merge_blocks op.
1081e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
10820bf4a82aSChristian Sigg       return op->getParentOfType<TestMergeBlocksOp>();
1083e888886cSMaheshRavishankar     });
1084e888886cSMaheshRavishankar 
1085e888886cSMaheshRavishankar     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1086e888886cSMaheshRavishankar     /// inlined.
1087e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1088e888886cSMaheshRavishankar         [&](SingleBlockImplicitTerminatorOp op) -> bool {
10890bf4a82aSChristian Sigg           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1090e888886cSMaheshRavishankar         });
1091e888886cSMaheshRavishankar 
1092e888886cSMaheshRavishankar     DenseSet<Operation *> unlegalizedOps;
10933fffffa8SRiver Riddle     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1094e888886cSMaheshRavishankar                                  &unlegalizedOps);
1095e888886cSMaheshRavishankar     for (auto *op : unlegalizedOps)
1096e888886cSMaheshRavishankar       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1097e888886cSMaheshRavishankar   }
1098e888886cSMaheshRavishankar };
1099e888886cSMaheshRavishankar } // namespace
1100e888886cSMaheshRavishankar 
11014589dd92SRiver Riddle //===----------------------------------------------------------------------===//
1102c8fb6ee3SRiver Riddle // Test Selective Replacement
1103c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1104c8fb6ee3SRiver Riddle 
1105c8fb6ee3SRiver Riddle namespace {
1106c8fb6ee3SRiver Riddle /// A rewrite mechanism to inline the body of the op into its parent, when both
1107c8fb6ee3SRiver Riddle /// ops can have a single block.
1108c8fb6ee3SRiver Riddle struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1109c8fb6ee3SRiver Riddle   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1110c8fb6ee3SRiver Riddle 
1111c8fb6ee3SRiver Riddle   LogicalResult matchAndRewrite(TestCastOp op,
1112c8fb6ee3SRiver Riddle                                 PatternRewriter &rewriter) const final {
1113c8fb6ee3SRiver Riddle     if (op.getNumOperands() != 2)
1114c8fb6ee3SRiver Riddle       return failure();
1115c8fb6ee3SRiver Riddle     OperandRange operands = op.getOperands();
1116c8fb6ee3SRiver Riddle 
1117c8fb6ee3SRiver Riddle     // Replace non-terminator uses with the first operand.
1118c8fb6ee3SRiver Riddle     rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1119fe7c0d90SRiver Riddle       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1120c8fb6ee3SRiver Riddle     });
1121c8fb6ee3SRiver Riddle     // Replace everything else with the second operand if the operation isn't
1122c8fb6ee3SRiver Riddle     // dead.
1123c8fb6ee3SRiver Riddle     rewriter.replaceOp(op, op.getOperand(1));
1124c8fb6ee3SRiver Riddle     return success();
1125c8fb6ee3SRiver Riddle   }
1126c8fb6ee3SRiver Riddle };
1127c8fb6ee3SRiver Riddle 
1128c8fb6ee3SRiver Riddle struct TestSelectiveReplacementPatternDriver
1129c8fb6ee3SRiver Riddle     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1130c8fb6ee3SRiver Riddle                          OperationPass<>> {
1131b5e22e6dSMehdi Amini   StringRef getArgument() const final {
1132b5e22e6dSMehdi Amini     return "test-pattern-selective-replacement";
1133b5e22e6dSMehdi Amini   }
1134b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1135b5e22e6dSMehdi Amini     return "Test selective replacement in the PatternRewriter";
1136b5e22e6dSMehdi Amini   }
1137c8fb6ee3SRiver Riddle   void runOnOperation() override {
1138c8fb6ee3SRiver Riddle     MLIRContext *context = &getContext();
1139dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(context);
1140dc4e913bSChris Lattner     patterns.add<TestSelectiveOpReplacementPattern>(context);
1141e21adfa3SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1142c8fb6ee3SRiver Riddle                                        std::move(patterns));
1143c8fb6ee3SRiver Riddle   }
1144c8fb6ee3SRiver Riddle };
1145c8fb6ee3SRiver Riddle } // namespace
1146c8fb6ee3SRiver Riddle 
1147c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
11484589dd92SRiver Riddle // PassRegistration
11494589dd92SRiver Riddle //===----------------------------------------------------------------------===//
11504589dd92SRiver Riddle 
1151fec6c5acSUday Bondhugula namespace mlir {
115272c65b69SAlexander Belyaev namespace test {
1153fec6c5acSUday Bondhugula void registerPatternsTestPass() {
1154b5e22e6dSMehdi Amini   PassRegistration<TestReturnTypeDriver>();
1155fec6c5acSUday Bondhugula 
1156b5e22e6dSMehdi Amini   PassRegistration<TestDerivedAttributeDriver>();
11579ba37b3bSJacques Pienaar 
1158b5e22e6dSMehdi Amini   PassRegistration<TestPatternDriver>();
1159fec6c5acSUday Bondhugula 
1160b5e22e6dSMehdi Amini   PassRegistration<TestLegalizePatternDriver>([] {
1161b5e22e6dSMehdi Amini     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
1162fec6c5acSUday Bondhugula   });
1163fec6c5acSUday Bondhugula 
1164b5e22e6dSMehdi Amini   PassRegistration<TestRemappedValue>();
116580d7ac3bSRiver Riddle 
1166b5e22e6dSMehdi Amini   PassRegistration<TestUnknownRootOpDriver>();
11674589dd92SRiver Riddle 
1168b5e22e6dSMehdi Amini   PassRegistration<TestTypeConversionDriver>();
1169e888886cSMaheshRavishankar 
1170b5e22e6dSMehdi Amini   PassRegistration<TestMergeBlocksPatternDriver>();
1171b5e22e6dSMehdi Amini   PassRegistration<TestSelectiveReplacementPatternDriver>();
1172fec6c5acSUday Bondhugula }
117372c65b69SAlexander Belyaev } // namespace test
1174fec6c5acSUday Bondhugula } // namespace mlir
1175