//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8fec6c5acSUday Bondhugula 
9fec6c5acSUday Bondhugula #include "TestDialect.h"
109c5982efSAlex Zinenko #include "TestTypes.h"
11a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1223aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14c0a6318dSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
152bf423b0SRob Suderman #include "mlir/IR/Matchers.h"
16fec6c5acSUday Bondhugula #include "mlir/Pass/Pass.h"
17fec6c5acSUday Bondhugula #include "mlir/Transforms/DialectConversion.h"
1826f93d9fSAlex Zinenko #include "mlir/Transforms/FoldUtils.h"
19b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
209ba37b3bSJacques Pienaar 
21fec6c5acSUday Bondhugula using namespace mlir;
227776b19eSStephen Neuendorffer using namespace test;
23fec6c5acSUday Bondhugula 
24fec6c5acSUday Bondhugula // Native function for testing NativeCodeCall
chooseOperand(Value input1,Value input2,BoolAttr choice)25fec6c5acSUday Bondhugula static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
26fec6c5acSUday Bondhugula   return choice.getValue() ? input1 : input2;
27fec6c5acSUday Bondhugula }
28fec6c5acSUday Bondhugula 
createOpI(PatternRewriter & rewriter,Location loc,Value input)2929429d1aSJacques Pienaar static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
3029429d1aSJacques Pienaar   rewriter.create<OpI>(loc, input);
31fec6c5acSUday Bondhugula }
32fec6c5acSUday Bondhugula 
handleNoResultOp(PatternRewriter & rewriter,OpSymbolBindingNoResult op)33fec6c5acSUday Bondhugula static void handleNoResultOp(PatternRewriter &rewriter,
34fec6c5acSUday Bondhugula                              OpSymbolBindingNoResult op) {
35fec6c5acSUday Bondhugula   // Turn the no result op to a one-result op.
366a994233SJacques Pienaar   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
376a994233SJacques Pienaar                                     op.getOperand());
38fec6c5acSUday Bondhugula }
39fec6c5acSUday Bondhugula 
getFirstI32Result(Operation * op,Value & value)4034b5482bSChia-hung Duan static bool getFirstI32Result(Operation *op, Value &value) {
4134b5482bSChia-hung Duan   if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
4234b5482bSChia-hung Duan     return false;
4334b5482bSChia-hung Duan   value = op->getResult(0);
4434b5482bSChia-hung Duan   return true;
4534b5482bSChia-hung Duan }
4634b5482bSChia-hung Duan 
bindNativeCodeCallResult(Value value)4734b5482bSChia-hung Duan static Value bindNativeCodeCallResult(Value value) { return value; }
4834b5482bSChia-hung Duan 
bindMultipleNativeCodeCallResult(Value input1,Value input2)49d7314b3cSChia-hung Duan static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
50d7314b3cSChia-hung Duan                                                               Value input2) {
51d7314b3cSChia-hung Duan   return SmallVector<Value, 2>({input2, input1});
52d7314b3cSChia-hung Duan }
53d7314b3cSChia-hung Duan 
5401641197SAlexEichenberger // Test that natives calls are only called once during rewrites.
5501641197SAlexEichenberger // OpM_Test will return Pi, increased by 1 for each subsequent calls.
5601641197SAlexEichenberger // This let us check the number of times OpM_Test was called by inspecting
5701641197SAlexEichenberger // the returned value in the MLIR output.
5801641197SAlexEichenberger static int64_t opMIncreasingValue = 314159265;
opMTest(PatternRewriter & rewriter,Value val)5902b6fb21SMehdi Amini static Attribute opMTest(PatternRewriter &rewriter, Value val) {
6001641197SAlexEichenberger   int64_t i = opMIncreasingValue++;
6101641197SAlexEichenberger   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
6201641197SAlexEichenberger }
6301641197SAlexEichenberger 
64fec6c5acSUday Bondhugula namespace {
65fec6c5acSUday Bondhugula #include "TestPatterns.inc"
66be0a7e9fSMehdi Amini } // namespace
67fec6c5acSUday Bondhugula 
//===----------------------------------------------------------------------===//
// Test Reduce Pattern Interface
//===----------------------------------------------------------------------===//
71c484c7ddSChia-hung Duan 
populateTestReductionPatterns(RewritePatternSet & patterns)727776b19eSStephen Neuendorffer void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
73c484c7ddSChia-hung Duan   populateWithGenerated(patterns);
74c484c7ddSChia-hung Duan }
75c484c7ddSChia-hung Duan 
//===----------------------------------------------------------------------===//
// Canonicalizer Driver.
//===----------------------------------------------------------------------===//
79fec6c5acSUday Bondhugula 
80fec6c5acSUday Bondhugula namespace {
8126f93d9fSAlex Zinenko struct FoldingPattern : public RewritePattern {
8226f93d9fSAlex Zinenko public:
FoldingPattern__anon71cd160d0211::FoldingPattern8326f93d9fSAlex Zinenko   FoldingPattern(MLIRContext *context)
8426f93d9fSAlex Zinenko       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
8526f93d9fSAlex Zinenko                        /*benefit=*/1, context) {}
8626f93d9fSAlex Zinenko 
matchAndRewrite__anon71cd160d0211::FoldingPattern8726f93d9fSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
8826f93d9fSAlex Zinenko                                 PatternRewriter &rewriter) const override {
892b638ed5SKazuaki Ishizaki     // Exercise OperationFolder API for a single-result operation that is folded
9026f93d9fSAlex Zinenko     // upon construction. The operation being created through the folder has an
9126f93d9fSAlex Zinenko     // in-place folder, and it should be still present in the output.
9226f93d9fSAlex Zinenko     // Furthermore, the folder should not crash when attempting to recover the
93a23d0559SKazuaki Ishizaki     // (unchanged) operation result.
9426f93d9fSAlex Zinenko     OperationFolder folder(op->getContext());
9526f93d9fSAlex Zinenko     Value result = folder.create<TestOpInPlaceFold>(
9626f93d9fSAlex Zinenko         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
9726f93d9fSAlex Zinenko         rewriter.getI32IntegerAttr(0));
9826f93d9fSAlex Zinenko     assert(result);
9926f93d9fSAlex Zinenko     rewriter.replaceOp(op, result);
10026f93d9fSAlex Zinenko     return success();
10126f93d9fSAlex Zinenko   }
10226f93d9fSAlex Zinenko };
10326f93d9fSAlex Zinenko 
104e4635e63SRiver Riddle /// This pattern creates a foldable operation at the entry point of the block.
105e4635e63SRiver Riddle /// This tests the situation where the operation folder will need to replace an
106e4635e63SRiver Riddle /// operation with a previously created constant that does not initially
107e4635e63SRiver Riddle /// dominate the operation to replace.
108e4635e63SRiver Riddle struct FolderInsertBeforePreviouslyFoldedConstantPattern
109e4635e63SRiver Riddle     : public OpRewritePattern<TestCastOp> {
110e4635e63SRiver Riddle public:
111e4635e63SRiver Riddle   using OpRewritePattern<TestCastOp>::OpRewritePattern;
112e4635e63SRiver Riddle 
matchAndRewrite__anon71cd160d0211::FolderInsertBeforePreviouslyFoldedConstantPattern113e4635e63SRiver Riddle   LogicalResult matchAndRewrite(TestCastOp op,
114e4635e63SRiver Riddle                                 PatternRewriter &rewriter) const override {
115e4635e63SRiver Riddle     if (!op->hasAttr("test_fold_before_previously_folded_op"))
116e4635e63SRiver Riddle       return failure();
117e4635e63SRiver Riddle     rewriter.setInsertionPointToStart(op->getBlock());
118e4635e63SRiver Riddle 
119a54f4eaeSMogball     auto constOp = rewriter.create<arith::ConstantOp>(
120a54f4eaeSMogball         op.getLoc(), rewriter.getBoolAttr(true));
121e4635e63SRiver Riddle     rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
122e4635e63SRiver Riddle                                             Value(constOp));
123e4635e63SRiver Riddle     return success();
124e4635e63SRiver Riddle   }
125e4635e63SRiver Riddle };
126e4635e63SRiver Riddle 
12743f0d5f9SMehdi Amini /// This pattern matches test.op_commutative2 with the first operand being
12843f0d5f9SMehdi Amini /// another test.op_commutative2 with a constant on the right side and fold it
12943f0d5f9SMehdi Amini /// away by propagating it as its result. This is intend to check that patterns
13043f0d5f9SMehdi Amini /// are applied after the commutative property moves constant to the right.
13143f0d5f9SMehdi Amini struct FolderCommutativeOp2WithConstant
13243f0d5f9SMehdi Amini     : public OpRewritePattern<TestCommutative2Op> {
13343f0d5f9SMehdi Amini public:
13443f0d5f9SMehdi Amini   using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
13543f0d5f9SMehdi Amini 
matchAndRewrite__anon71cd160d0211::FolderCommutativeOp2WithConstant13643f0d5f9SMehdi Amini   LogicalResult matchAndRewrite(TestCommutative2Op op,
13743f0d5f9SMehdi Amini                                 PatternRewriter &rewriter) const override {
13843f0d5f9SMehdi Amini     auto operand =
13943f0d5f9SMehdi Amini         dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
14043f0d5f9SMehdi Amini     if (!operand)
14143f0d5f9SMehdi Amini       return failure();
14243f0d5f9SMehdi Amini     Attribute constInput;
14343f0d5f9SMehdi Amini     if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
14443f0d5f9SMehdi Amini       return failure();
14543f0d5f9SMehdi Amini     rewriter.replaceOp(op, operand->getOperand(1));
14643f0d5f9SMehdi Amini     return success();
14743f0d5f9SMehdi Amini   }
14843f0d5f9SMehdi Amini };
14943f0d5f9SMehdi Amini 
15041574554SRiver Riddle struct TestPatternDriver
15158ceae95SRiver Riddle     : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
1525e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
1535e50dd04SRiver Riddle 
1547814b559Srkayaith   TestPatternDriver() = default;
TestPatternDriver__anon71cd160d0211::TestPatternDriver1557814b559Srkayaith   TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
1567814b559Srkayaith 
getArgument__anon71cd160d0211::TestPatternDriver157b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-patterns"; }
getDescription__anon71cd160d0211::TestPatternDriver158b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "Run test dialect patterns"; }
runOnOperation__anon71cd160d0211::TestPatternDriver15941574554SRiver Riddle   void runOnOperation() override {
160dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
16183e3c6a7SMehdi Amini     populateWithGenerated(patterns);
162fec6c5acSUday Bondhugula 
163fec6c5acSUday Bondhugula     // Verify named pattern is generated with expected name.
164e4635e63SRiver Riddle     patterns.add<FoldingPattern, TestNamedPatternRule,
16543f0d5f9SMehdi Amini                  FolderInsertBeforePreviouslyFoldedConstantPattern,
16643f0d5f9SMehdi Amini                  FolderCommutativeOp2WithConstant>(&getContext());
167fec6c5acSUday Bondhugula 
1687814b559Srkayaith     GreedyRewriteConfig config;
1697814b559Srkayaith     config.useTopDownTraversal = this->useTopDownTraversal;
1707814b559Srkayaith     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
1717814b559Srkayaith                                        config);
172fec6c5acSUday Bondhugula   }
1737814b559Srkayaith 
1747814b559Srkayaith   Option<bool> useTopDownTraversal{
1757814b559Srkayaith       *this, "top-down",
1767814b559Srkayaith       llvm::cl::desc("Seed the worklist in general top-down order"),
1777814b559Srkayaith       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
178fec6c5acSUday Bondhugula };
179*ba3a9f51SChia-hung Duan 
180*ba3a9f51SChia-hung Duan struct TestStrictPatternDriver
181*ba3a9f51SChia-hung Duan     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
182*ba3a9f51SChia-hung Duan public:
183*ba3a9f51SChia-hung Duan   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
184*ba3a9f51SChia-hung Duan 
185*ba3a9f51SChia-hung Duan   TestStrictPatternDriver() = default;
TestStrictPatternDriver__anon71cd160d0211::TestStrictPatternDriver186*ba3a9f51SChia-hung Duan   TestStrictPatternDriver(const TestStrictPatternDriver &other)
187*ba3a9f51SChia-hung Duan       : PassWrapper(other) {}
188*ba3a9f51SChia-hung Duan 
getArgument__anon71cd160d0211::TestStrictPatternDriver189*ba3a9f51SChia-hung Duan   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
getDescription__anon71cd160d0211::TestStrictPatternDriver190*ba3a9f51SChia-hung Duan   StringRef getDescription() const final {
191*ba3a9f51SChia-hung Duan     return "Run strict mode of pattern driver";
192*ba3a9f51SChia-hung Duan   }
193*ba3a9f51SChia-hung Duan 
runOnOperation__anon71cd160d0211::TestStrictPatternDriver194*ba3a9f51SChia-hung Duan   void runOnOperation() override {
195*ba3a9f51SChia-hung Duan     mlir::RewritePatternSet patterns(&getContext());
196*ba3a9f51SChia-hung Duan     patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
197*ba3a9f51SChia-hung Duan     SmallVector<Operation *> ops;
198*ba3a9f51SChia-hung Duan     getOperation()->walk([&](Operation *op) {
199*ba3a9f51SChia-hung Duan       StringRef opName = op->getName().getStringRef();
200*ba3a9f51SChia-hung Duan       if (opName == "test.insert_same_op" ||
201*ba3a9f51SChia-hung Duan           opName == "test.replace_with_same_op" || opName == "test.erase_op") {
202*ba3a9f51SChia-hung Duan         ops.push_back(op);
203*ba3a9f51SChia-hung Duan       }
204*ba3a9f51SChia-hung Duan     });
205*ba3a9f51SChia-hung Duan 
206*ba3a9f51SChia-hung Duan     // Check if these transformations introduce visiting of operations that
207*ba3a9f51SChia-hung Duan     // are not in the `ops` set (The new created ops are valid). An invalid
208*ba3a9f51SChia-hung Duan     // operation will trigger the assertion while processing.
209*ba3a9f51SChia-hung Duan     (void)applyOpPatternsAndFold(makeArrayRef(ops), std::move(patterns),
210*ba3a9f51SChia-hung Duan                                  /*strict=*/true);
211*ba3a9f51SChia-hung Duan   }
212*ba3a9f51SChia-hung Duan 
213*ba3a9f51SChia-hung Duan private:
214*ba3a9f51SChia-hung Duan   // New inserted operation is valid for further transformation.
215*ba3a9f51SChia-hung Duan   class InsertSameOp : public RewritePattern {
216*ba3a9f51SChia-hung Duan   public:
InsertSameOp(MLIRContext * context)217*ba3a9f51SChia-hung Duan     InsertSameOp(MLIRContext *context)
218*ba3a9f51SChia-hung Duan         : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
219*ba3a9f51SChia-hung Duan 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const220*ba3a9f51SChia-hung Duan     LogicalResult matchAndRewrite(Operation *op,
221*ba3a9f51SChia-hung Duan                                   PatternRewriter &rewriter) const override {
222*ba3a9f51SChia-hung Duan       if (op->hasAttr("skip"))
223*ba3a9f51SChia-hung Duan         return failure();
224*ba3a9f51SChia-hung Duan 
225*ba3a9f51SChia-hung Duan       Operation *newOp =
226*ba3a9f51SChia-hung Duan           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
227*ba3a9f51SChia-hung Duan                           op->getOperands(), op->getResultTypes());
228*ba3a9f51SChia-hung Duan       op->setAttr("skip", rewriter.getBoolAttr(true));
229*ba3a9f51SChia-hung Duan       newOp->setAttr("skip", rewriter.getBoolAttr(true));
230*ba3a9f51SChia-hung Duan 
231*ba3a9f51SChia-hung Duan       return success();
232*ba3a9f51SChia-hung Duan     }
233*ba3a9f51SChia-hung Duan   };
234*ba3a9f51SChia-hung Duan 
235*ba3a9f51SChia-hung Duan   // Replace an operation may introduce the re-visiting of its users.
236*ba3a9f51SChia-hung Duan   class ReplaceWithSameOp : public RewritePattern {
237*ba3a9f51SChia-hung Duan   public:
ReplaceWithSameOp(MLIRContext * context)238*ba3a9f51SChia-hung Duan     ReplaceWithSameOp(MLIRContext *context)
239*ba3a9f51SChia-hung Duan         : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
240*ba3a9f51SChia-hung Duan 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const241*ba3a9f51SChia-hung Duan     LogicalResult matchAndRewrite(Operation *op,
242*ba3a9f51SChia-hung Duan                                   PatternRewriter &rewriter) const override {
243*ba3a9f51SChia-hung Duan       Operation *newOp =
244*ba3a9f51SChia-hung Duan           rewriter.create(op->getLoc(), op->getName().getIdentifier(),
245*ba3a9f51SChia-hung Duan                           op->getOperands(), op->getResultTypes());
246*ba3a9f51SChia-hung Duan       rewriter.replaceOp(op, newOp->getResults());
247*ba3a9f51SChia-hung Duan       return success();
248*ba3a9f51SChia-hung Duan     }
249*ba3a9f51SChia-hung Duan   };
250*ba3a9f51SChia-hung Duan 
251*ba3a9f51SChia-hung Duan   // Remove an operation may introduce the re-visiting of its opreands.
252*ba3a9f51SChia-hung Duan   class EraseOp : public RewritePattern {
253*ba3a9f51SChia-hung Duan   public:
EraseOp(MLIRContext * context)254*ba3a9f51SChia-hung Duan     EraseOp(MLIRContext *context)
255*ba3a9f51SChia-hung Duan         : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const256*ba3a9f51SChia-hung Duan     LogicalResult matchAndRewrite(Operation *op,
257*ba3a9f51SChia-hung Duan                                   PatternRewriter &rewriter) const override {
258*ba3a9f51SChia-hung Duan       rewriter.eraseOp(op);
259*ba3a9f51SChia-hung Duan       return success();
260*ba3a9f51SChia-hung Duan     }
261*ba3a9f51SChia-hung Duan   };
262*ba3a9f51SChia-hung Duan };
263*ba3a9f51SChia-hung Duan 
264be0a7e9fSMehdi Amini } // namespace
265fec6c5acSUday Bondhugula 
//===----------------------------------------------------------------------===//
// ReturnType Driver.
//===----------------------------------------------------------------------===//
269fec6c5acSUday Bondhugula 
270fec6c5acSUday Bondhugula namespace {
271fec6c5acSUday Bondhugula // Generate ops for each instance where the type can be successfully inferred.
272fec6c5acSUday Bondhugula template <typename OpTy>
invokeCreateWithInferredReturnType(Operation * op)273fec6c5acSUday Bondhugula static void invokeCreateWithInferredReturnType(Operation *op) {
274fec6c5acSUday Bondhugula   auto *context = op->getContext();
27558ceae95SRiver Riddle   auto fop = op->getParentOfType<func::FuncOp>();
276fec6c5acSUday Bondhugula   auto location = UnknownLoc::get(context);
277fec6c5acSUday Bondhugula   OpBuilder b(op);
278fec6c5acSUday Bondhugula   b.setInsertionPointAfter(op);
279fec6c5acSUday Bondhugula 
280fec6c5acSUday Bondhugula   // Use permutations of 2 args as operands.
281fec6c5acSUday Bondhugula   assert(fop.getNumArguments() >= 2);
282fec6c5acSUday Bondhugula   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
283fec6c5acSUday Bondhugula     for (int j = 0; j < e; ++j) {
284fec6c5acSUday Bondhugula       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
285fec6c5acSUday Bondhugula       SmallVector<Type, 2> inferredReturnTypes;
2865eae715aSJacques Pienaar       if (succeeded(OpTy::inferReturnTypes(
2875eae715aSJacques Pienaar               context, llvm::None, values, op->getAttrDictionary(),
2885eae715aSJacques Pienaar               op->getRegions(), inferredReturnTypes))) {
289fec6c5acSUday Bondhugula         OperationState state(location, OpTy::getOperationName());
2909db53a18SRiver Riddle         // TODO: Expand to regions.
291bb1d976fSAlex Zinenko         OpTy::build(b, state, values, op->getAttrs());
29214ecafd0SChia-hung Duan         (void)b.create(state);
293fec6c5acSUday Bondhugula       }
294fec6c5acSUday Bondhugula     }
295fec6c5acSUday Bondhugula   }
296fec6c5acSUday Bondhugula }
297fec6c5acSUday Bondhugula 
reifyReturnShape(Operation * op)298fec6c5acSUday Bondhugula static void reifyReturnShape(Operation *op) {
299fec6c5acSUday Bondhugula   OpBuilder b(op);
300fec6c5acSUday Bondhugula 
301fec6c5acSUday Bondhugula   // Use permutations of 2 args as operands.
302fec6c5acSUday Bondhugula   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
303fec6c5acSUday Bondhugula   SmallVector<Value, 2> shapes;
304851d02f6SWenyi Zhao   if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
3059b051703SMaheshRavishankar       !llvm::hasSingleElement(shapes))
306fec6c5acSUday Bondhugula     return;
30789de9cc8SMehdi Amini   for (const auto &it : llvm::enumerate(shapes)) {
308fec6c5acSUday Bondhugula     op->emitRemark() << "value " << it.index() << ": "
309fec6c5acSUday Bondhugula                      << it.value().getDefiningOp();
310fec6c5acSUday Bondhugula   }
3119b051703SMaheshRavishankar }
312fec6c5acSUday Bondhugula 
31380aca1eaSRiver Riddle struct TestReturnTypeDriver
31458ceae95SRiver Riddle     : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d0411::TestReturnTypeDriver3155e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
3165e50dd04SRiver Riddle 
317e2310704SJulian Gross   void getDependentDialects(DialectRegistry &registry) const override {
318c0a6318dSMatthias Springer     registry.insert<tensor::TensorDialect>();
319e2310704SJulian Gross   }
getArgument__anon71cd160d0411::TestReturnTypeDriver320b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-return-type"; }
getDescription__anon71cd160d0411::TestReturnTypeDriver321b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "Run return type functions"; }
322e2310704SJulian Gross 
runOnOperation__anon71cd160d0411::TestReturnTypeDriver32341574554SRiver Riddle   void runOnOperation() override {
32441574554SRiver Riddle     if (getOperation().getName() == "testCreateFunctions") {
325fec6c5acSUday Bondhugula       std::vector<Operation *> ops;
326fec6c5acSUday Bondhugula       // Collect ops to avoid triggering on inserted ops.
32741574554SRiver Riddle       for (auto &op : getOperation().getBody().front())
328fec6c5acSUday Bondhugula         ops.push_back(&op);
329fec6c5acSUday Bondhugula       // Generate test patterns for each, but skip terminator.
330fec6c5acSUday Bondhugula       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
331fec6c5acSUday Bondhugula         // Test create method of each of the Op classes below. The resultant
332fec6c5acSUday Bondhugula         // output would be in reverse order underneath `op` from which
333fec6c5acSUday Bondhugula         // the attributes and regions are used.
334fec6c5acSUday Bondhugula         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
335fec6c5acSUday Bondhugula         invokeCreateWithInferredReturnType<
336fec6c5acSUday Bondhugula             OpWithShapedTypeInferTypeInterfaceOp>(op);
337fec6c5acSUday Bondhugula       };
338fec6c5acSUday Bondhugula       return;
339fec6c5acSUday Bondhugula     }
34041574554SRiver Riddle     if (getOperation().getName() == "testReifyFunctions") {
341fec6c5acSUday Bondhugula       std::vector<Operation *> ops;
342fec6c5acSUday Bondhugula       // Collect ops to avoid triggering on inserted ops.
34341574554SRiver Riddle       for (auto &op : getOperation().getBody().front())
344fec6c5acSUday Bondhugula         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
345fec6c5acSUday Bondhugula           ops.push_back(&op);
346fec6c5acSUday Bondhugula       // Generate test patterns for each, but skip terminator.
347fec6c5acSUday Bondhugula       for (auto *op : ops)
348fec6c5acSUday Bondhugula         reifyReturnShape(op);
349fec6c5acSUday Bondhugula     }
350fec6c5acSUday Bondhugula   }
351fec6c5acSUday Bondhugula };
352be0a7e9fSMehdi Amini } // namespace
353fec6c5acSUday Bondhugula 
3549ba37b3bSJacques Pienaar namespace {
3559ba37b3bSJacques Pienaar struct TestDerivedAttributeDriver
35658ceae95SRiver Riddle     : public PassWrapper<TestDerivedAttributeDriver,
35758ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d0511::TestDerivedAttributeDriver3585e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
3595e50dd04SRiver Riddle 
360b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-derived-attr"; }
getDescription__anon71cd160d0511::TestDerivedAttributeDriver361b5e22e6dSMehdi Amini   StringRef getDescription() const final {
362b5e22e6dSMehdi Amini     return "Run test derived attributes";
363b5e22e6dSMehdi Amini   }
36441574554SRiver Riddle   void runOnOperation() override;
3659ba37b3bSJacques Pienaar };
366be0a7e9fSMehdi Amini } // namespace
3679ba37b3bSJacques Pienaar 
runOnOperation()36841574554SRiver Riddle void TestDerivedAttributeDriver::runOnOperation() {
36941574554SRiver Riddle   getOperation().walk([](DerivedAttributeOpInterface dOp) {
3709ba37b3bSJacques Pienaar     auto dAttr = dOp.materializeDerivedAttributes();
3719ba37b3bSJacques Pienaar     if (!dAttr)
3729ba37b3bSJacques Pienaar       return;
3739ba37b3bSJacques Pienaar     for (auto d : dAttr)
3740c7890c8SRiver Riddle       dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
3759ba37b3bSJacques Pienaar   });
3769ba37b3bSJacques Pienaar }
3779ba37b3bSJacques Pienaar 
//===----------------------------------------------------------------------===//
// Legalization Driver.
//===----------------------------------------------------------------------===//
381fec6c5acSUday Bondhugula 
382fec6c5acSUday Bondhugula namespace {
383fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
384fec6c5acSUday Bondhugula // Region-Block Rewrite Testing
385fec6c5acSUday Bondhugula 
386fec6c5acSUday Bondhugula /// This pattern is a simple pattern that inlines the first region of a given
387fec6c5acSUday Bondhugula /// operation into the parent region.
388fec6c5acSUday Bondhugula struct TestRegionRewriteBlockMovement : public ConversionPattern {
TestRegionRewriteBlockMovement__anon71cd160d0711::TestRegionRewriteBlockMovement389fec6c5acSUday Bondhugula   TestRegionRewriteBlockMovement(MLIRContext *ctx)
390fec6c5acSUday Bondhugula       : ConversionPattern("test.region", 1, ctx) {}
391fec6c5acSUday Bondhugula 
392fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d0711::TestRegionRewriteBlockMovement393fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
394fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
395fec6c5acSUday Bondhugula     // Inline this region into the parent region.
396fec6c5acSUday Bondhugula     auto &parentRegion = *op->getParentRegion();
397b0750e2dSTres Popp     auto &opRegion = op->getRegion(0);
398fec6c5acSUday Bondhugula     if (op->getAttr("legalizer.should_clone"))
399b0750e2dSTres Popp       rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
400fec6c5acSUday Bondhugula     else
401b0750e2dSTres Popp       rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
402b0750e2dSTres Popp 
403b0750e2dSTres Popp     if (op->getAttr("legalizer.erase_old_blocks")) {
404b0750e2dSTres Popp       while (!opRegion.empty())
405b0750e2dSTres Popp         rewriter.eraseBlock(&opRegion.front());
406b0750e2dSTres Popp     }
407fec6c5acSUday Bondhugula 
408fec6c5acSUday Bondhugula     // Drop this operation.
409fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
410fec6c5acSUday Bondhugula     return success();
411fec6c5acSUday Bondhugula   }
412fec6c5acSUday Bondhugula };
413fec6c5acSUday Bondhugula /// This pattern is a simple pattern that generates a region containing an
414fec6c5acSUday Bondhugula /// illegal operation.
415fec6c5acSUday Bondhugula struct TestRegionRewriteUndo : public RewritePattern {
TestRegionRewriteUndo__anon71cd160d0711::TestRegionRewriteUndo416fec6c5acSUday Bondhugula   TestRegionRewriteUndo(MLIRContext *ctx)
417fec6c5acSUday Bondhugula       : RewritePattern("test.region_builder", 1, ctx) {}
418fec6c5acSUday Bondhugula 
matchAndRewrite__anon71cd160d0711::TestRegionRewriteUndo419fec6c5acSUday Bondhugula   LogicalResult matchAndRewrite(Operation *op,
420fec6c5acSUday Bondhugula                                 PatternRewriter &rewriter) const final {
421fec6c5acSUday Bondhugula     // Create the region operation with an entry block containing arguments.
422fec6c5acSUday Bondhugula     OperationState newRegion(op->getLoc(), "test.region");
423fec6c5acSUday Bondhugula     newRegion.addRegion();
42414ecafd0SChia-hung Duan     auto *regionOp = rewriter.create(newRegion);
425fec6c5acSUday Bondhugula     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
426e084679fSRiver Riddle     entryBlock->addArgument(rewriter.getIntegerType(64),
427e084679fSRiver Riddle                             rewriter.getUnknownLoc());
428fec6c5acSUday Bondhugula 
429fec6c5acSUday Bondhugula     // Add an explicitly illegal operation to ensure the conversion fails.
430fec6c5acSUday Bondhugula     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
431fec6c5acSUday Bondhugula     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
432fec6c5acSUday Bondhugula 
433fec6c5acSUday Bondhugula     // Drop this operation.
434fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
435fec6c5acSUday Bondhugula     return success();
436fec6c5acSUday Bondhugula   }
437fec6c5acSUday Bondhugula };
438f27f1e8cSAlex Zinenko /// A simple pattern that creates a block at the end of the parent region of the
439f27f1e8cSAlex Zinenko /// matched operation.
440f27f1e8cSAlex Zinenko struct TestCreateBlock : public RewritePattern {
TestCreateBlock__anon71cd160d0711::TestCreateBlock441f27f1e8cSAlex Zinenko   TestCreateBlock(MLIRContext *ctx)
442f27f1e8cSAlex Zinenko       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
443f27f1e8cSAlex Zinenko 
matchAndRewrite__anon71cd160d0711::TestCreateBlock444f27f1e8cSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
445f27f1e8cSAlex Zinenko                                 PatternRewriter &rewriter) const final {
446f27f1e8cSAlex Zinenko     Region &region = *op->getParentRegion();
447f27f1e8cSAlex Zinenko     Type i32Type = rewriter.getIntegerType(32);
448e084679fSRiver Riddle     Location loc = op->getLoc();
449e084679fSRiver Riddle     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
450e084679fSRiver Riddle     rewriter.create<TerminatorOp>(loc);
451f27f1e8cSAlex Zinenko     rewriter.replaceOp(op, {});
452f27f1e8cSAlex Zinenko     return success();
453f27f1e8cSAlex Zinenko   }
454f27f1e8cSAlex Zinenko };
455f27f1e8cSAlex Zinenko 
456a23d0559SKazuaki Ishizaki /// A simple pattern that creates a block containing an invalid operation in
457f27f1e8cSAlex Zinenko /// order to trigger the block creation undo mechanism.
458f27f1e8cSAlex Zinenko struct TestCreateIllegalBlock : public RewritePattern {
TestCreateIllegalBlock__anon71cd160d0711::TestCreateIllegalBlock459f27f1e8cSAlex Zinenko   TestCreateIllegalBlock(MLIRContext *ctx)
460f27f1e8cSAlex Zinenko       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
461f27f1e8cSAlex Zinenko 
matchAndRewrite__anon71cd160d0711::TestCreateIllegalBlock462f27f1e8cSAlex Zinenko   LogicalResult matchAndRewrite(Operation *op,
463f27f1e8cSAlex Zinenko                                 PatternRewriter &rewriter) const final {
464f27f1e8cSAlex Zinenko     Region &region = *op->getParentRegion();
465f27f1e8cSAlex Zinenko     Type i32Type = rewriter.getIntegerType(32);
466e084679fSRiver Riddle     Location loc = op->getLoc();
467e084679fSRiver Riddle     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
468f27f1e8cSAlex Zinenko     // Create an illegal op to ensure the conversion fails.
469e084679fSRiver Riddle     rewriter.create<ILLegalOpF>(loc, i32Type);
470e084679fSRiver Riddle     rewriter.create<TerminatorOp>(loc);
471f27f1e8cSAlex Zinenko     rewriter.replaceOp(op, {});
472f27f1e8cSAlex Zinenko     return success();
473f27f1e8cSAlex Zinenko   }
474f27f1e8cSAlex Zinenko };
475fec6c5acSUday Bondhugula 
4760816de16SRiver Riddle /// A simple pattern that tests the undo mechanism when replacing the uses of a
4770816de16SRiver Riddle /// block argument.
4780816de16SRiver Riddle struct TestUndoBlockArgReplace : public ConversionPattern {
TestUndoBlockArgReplace__anon71cd160d0711::TestUndoBlockArgReplace4790816de16SRiver Riddle   TestUndoBlockArgReplace(MLIRContext *ctx)
4800816de16SRiver Riddle       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
4810816de16SRiver Riddle 
4820816de16SRiver Riddle   LogicalResult
matchAndRewrite__anon71cd160d0711::TestUndoBlockArgReplace4830816de16SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
4840816de16SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
4850816de16SRiver Riddle     auto illegalOp =
4860816de16SRiver Riddle         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
487e2b71610SRahul Joshi     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
4880816de16SRiver Riddle                                         illegalOp);
4890816de16SRiver Riddle     rewriter.updateRootInPlace(op, [] {});
4900816de16SRiver Riddle     return success();
4910816de16SRiver Riddle   }
4920816de16SRiver Riddle };
4930816de16SRiver Riddle 
494df48026bSAlex Zinenko /// A rewrite pattern that tests the undo mechanism when erasing a block.
495df48026bSAlex Zinenko struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase__anon71cd160d0711::TestUndoBlockErase496df48026bSAlex Zinenko   TestUndoBlockErase(MLIRContext *ctx)
497df48026bSAlex Zinenko       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
498df48026bSAlex Zinenko 
499df48026bSAlex Zinenko   LogicalResult
matchAndRewrite__anon71cd160d0711::TestUndoBlockErase500df48026bSAlex Zinenko   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
501df48026bSAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
502df48026bSAlex Zinenko     Block *secondBlock = &*std::next(op->getRegion(0).begin());
503df48026bSAlex Zinenko     rewriter.setInsertionPointToStart(secondBlock);
504df48026bSAlex Zinenko     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
505df48026bSAlex Zinenko     rewriter.eraseBlock(secondBlock);
506df48026bSAlex Zinenko     rewriter.updateRootInPlace(op, [] {});
507df48026bSAlex Zinenko     return success();
508df48026bSAlex Zinenko   }
509df48026bSAlex Zinenko };
510df48026bSAlex Zinenko 
511fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
512fec6c5acSUday Bondhugula // Type-Conversion Rewrite Testing
513fec6c5acSUday Bondhugula 
514fec6c5acSUday Bondhugula /// This patterns erases a region operation that has had a type conversion.
515fec6c5acSUday Bondhugula struct TestDropOpSignatureConversion : public ConversionPattern {
TestDropOpSignatureConversion__anon71cd160d0711::TestDropOpSignatureConversion516fec6c5acSUday Bondhugula   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
51776f3c2f3SRiver Riddle       : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
518fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d0711::TestDropOpSignatureConversion519fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
520fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const override {
521fec6c5acSUday Bondhugula     Region &region = op->getRegion(0);
522fec6c5acSUday Bondhugula     Block *entry = &region.front();
523fec6c5acSUday Bondhugula 
524fec6c5acSUday Bondhugula     // Convert the original entry arguments.
5258d67d187SRiver Riddle     TypeConverter &converter = *getTypeConverter();
526fec6c5acSUday Bondhugula     TypeConverter::SignatureConversion result(entry->getNumArguments());
5278d67d187SRiver Riddle     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
5288d67d187SRiver Riddle                                               result)) ||
5298d67d187SRiver Riddle         failed(rewriter.convertRegionTypes(&region, converter, &result)))
530fec6c5acSUday Bondhugula       return failure();
531fec6c5acSUday Bondhugula 
532fec6c5acSUday Bondhugula     // Convert the region signature and just drop the operation.
533fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
534fec6c5acSUday Bondhugula     return success();
535fec6c5acSUday Bondhugula   }
536fec6c5acSUday Bondhugula };
537fec6c5acSUday Bondhugula /// This pattern simply updates the operands of the given operation.
538fec6c5acSUday Bondhugula struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp__anon71cd160d0711::TestPassthroughInvalidOp539fec6c5acSUday Bondhugula   TestPassthroughInvalidOp(MLIRContext *ctx)
540fec6c5acSUday Bondhugula       : ConversionPattern("test.invalid", 1, ctx) {}
541fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d0711::TestPassthroughInvalidOp542fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
543fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
544fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
545fec6c5acSUday Bondhugula                                              llvm::None);
546fec6c5acSUday Bondhugula     return success();
547fec6c5acSUday Bondhugula   }
548fec6c5acSUday Bondhugula };
549fec6c5acSUday Bondhugula /// This pattern handles the case of a split return value.
550fec6c5acSUday Bondhugula struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType__anon71cd160d0711::TestSplitReturnType551fec6c5acSUday Bondhugula   TestSplitReturnType(MLIRContext *ctx)
552fec6c5acSUday Bondhugula       : ConversionPattern("test.return", 1, ctx) {}
553fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d0711::TestSplitReturnType554fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
555fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
556fec6c5acSUday Bondhugula     // Check for a return of F32.
557fec6c5acSUday Bondhugula     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
558fec6c5acSUday Bondhugula       return failure();
559fec6c5acSUday Bondhugula 
560fec6c5acSUday Bondhugula     // Check if the first operation is a cast operation, if it is we use the
561fec6c5acSUday Bondhugula     // results directly.
562fec6c5acSUday Bondhugula     auto *defOp = operands[0].getDefiningOp();
563015192c6SRiver Riddle     if (auto packerOp =
564015192c6SRiver Riddle             llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
565fec6c5acSUday Bondhugula       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
566fec6c5acSUday Bondhugula       return success();
567fec6c5acSUday Bondhugula     }
568fec6c5acSUday Bondhugula 
569fec6c5acSUday Bondhugula     // Otherwise, fail to match.
570fec6c5acSUday Bondhugula     return failure();
571fec6c5acSUday Bondhugula   }
572fec6c5acSUday Bondhugula };
573fec6c5acSUday Bondhugula 
574fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
575fec6c5acSUday Bondhugula // Multi-Level Type-Conversion Rewrite Testing
576fec6c5acSUday Bondhugula struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
TestChangeProducerTypeI32ToF32__anon71cd160d0711::TestChangeProducerTypeI32ToF32577fec6c5acSUday Bondhugula   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
578fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 1, ctx) {}
579fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d0711::TestChangeProducerTypeI32ToF32580fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
581fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
582fec6c5acSUday Bondhugula     // If the type is I32, change the type to F32.
583fec6c5acSUday Bondhugula     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
584fec6c5acSUday Bondhugula       return failure();
585fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
586fec6c5acSUday Bondhugula     return success();
587fec6c5acSUday Bondhugula   }
588fec6c5acSUday Bondhugula };
589fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
TestChangeProducerTypeF32ToF64__anon71cd160d0711::TestChangeProducerTypeF32ToF64590fec6c5acSUday Bondhugula   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
591fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 1, ctx) {}
592fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d0711::TestChangeProducerTypeF32ToF64593fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
594fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
595fec6c5acSUday Bondhugula     // If the type is F32, change the type to F64.
596fec6c5acSUday Bondhugula     if (!Type(*op->result_type_begin()).isF32())
597fec6c5acSUday Bondhugula       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
598fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
599fec6c5acSUday Bondhugula     return success();
600fec6c5acSUday Bondhugula   }
601fec6c5acSUday Bondhugula };
602fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
TestChangeProducerTypeF32ToInvalid__anon71cd160d0711::TestChangeProducerTypeF32ToInvalid603fec6c5acSUday Bondhugula   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
604fec6c5acSUday Bondhugula       : ConversionPattern("test.type_producer", 10, ctx) {}
605fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d0711::TestChangeProducerTypeF32ToInvalid606fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
607fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
608fec6c5acSUday Bondhugula     // Always convert to B16, even though it is not a legal type. This tests
609fec6c5acSUday Bondhugula     // that values are unmapped correctly.
610fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
611fec6c5acSUday Bondhugula     return success();
612fec6c5acSUday Bondhugula   }
613fec6c5acSUday Bondhugula };
614fec6c5acSUday Bondhugula struct TestUpdateConsumerType : public ConversionPattern {
TestUpdateConsumerType__anon71cd160d0711::TestUpdateConsumerType615fec6c5acSUday Bondhugula   TestUpdateConsumerType(MLIRContext *ctx)
616fec6c5acSUday Bondhugula       : ConversionPattern("test.type_consumer", 1, ctx) {}
617fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d0711::TestUpdateConsumerType618fec6c5acSUday Bondhugula   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
619fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const final {
620fec6c5acSUday Bondhugula     // Verify that the incoming operand has been successfully remapped to F64.
621fec6c5acSUday Bondhugula     if (!operands[0].getType().isF64())
622fec6c5acSUday Bondhugula       return failure();
623fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
624fec6c5acSUday Bondhugula     return success();
625fec6c5acSUday Bondhugula   }
626fec6c5acSUday Bondhugula };
627fec6c5acSUday Bondhugula 
628fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
629fec6c5acSUday Bondhugula // Non-Root Replacement Rewrite Testing
630fec6c5acSUday Bondhugula /// This pattern generates an invalid operation, but replaces it before the
631fec6c5acSUday Bondhugula /// pattern is finished. This checks that we don't need to legalize the
632fec6c5acSUday Bondhugula /// temporary op.
633fec6c5acSUday Bondhugula struct TestNonRootReplacement : public RewritePattern {
TestNonRootReplacement__anon71cd160d0711::TestNonRootReplacement634fec6c5acSUday Bondhugula   TestNonRootReplacement(MLIRContext *ctx)
635fec6c5acSUday Bondhugula       : RewritePattern("test.replace_non_root", 1, ctx) {}
636fec6c5acSUday Bondhugula 
matchAndRewrite__anon71cd160d0711::TestNonRootReplacement637fec6c5acSUday Bondhugula   LogicalResult matchAndRewrite(Operation *op,
638fec6c5acSUday Bondhugula                                 PatternRewriter &rewriter) const final {
639fec6c5acSUday Bondhugula     auto resultType = *op->result_type_begin();
640fec6c5acSUday Bondhugula     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
641fec6c5acSUday Bondhugula     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
642fec6c5acSUday Bondhugula 
643fec6c5acSUday Bondhugula     rewriter.replaceOp(illegalOp, {legalOp});
644fec6c5acSUday Bondhugula     rewriter.replaceOp(op, {illegalOp});
645fec6c5acSUday Bondhugula     return success();
646fec6c5acSUday Bondhugula   }
647fec6c5acSUday Bondhugula };
648bd1ccfe6SRiver Riddle 
649bd1ccfe6SRiver Riddle //===----------------------------------------------------------------------===//
650bd1ccfe6SRiver Riddle // Recursive Rewrite Testing
651bd1ccfe6SRiver Riddle /// This pattern is applied to the same operation multiple times, but has a
652bd1ccfe6SRiver Riddle /// bounded recursion.
653bd1ccfe6SRiver Riddle struct TestBoundedRecursiveRewrite
654bd1ccfe6SRiver Riddle     : public OpRewritePattern<TestRecursiveRewriteOp> {
6552257e4a7SRiver Riddle   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
6562257e4a7SRiver Riddle 
initialize__anon71cd160d0711::TestBoundedRecursiveRewrite6572257e4a7SRiver Riddle   void initialize() {
658b99bd771SRiver Riddle     // The conversion target handles bounding the recursion of this pattern.
659b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
660b99bd771SRiver Riddle   }
661bd1ccfe6SRiver Riddle 
matchAndRewrite__anon71cd160d0711::TestBoundedRecursiveRewrite662bd1ccfe6SRiver Riddle   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
663bd1ccfe6SRiver Riddle                                 PatternRewriter &rewriter) const final {
664bd1ccfe6SRiver Riddle     // Decrement the depth of the op in-place.
665bd1ccfe6SRiver Riddle     rewriter.updateRootInPlace(op, [&] {
6666a994233SJacques Pienaar       op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
667bd1ccfe6SRiver Riddle     });
668bd1ccfe6SRiver Riddle     return success();
669bd1ccfe6SRiver Riddle   }
670bd1ccfe6SRiver Riddle };
6715d5df06aSAlex Zinenko 
6725d5df06aSAlex Zinenko struct TestNestedOpCreationUndoRewrite
6735d5df06aSAlex Zinenko     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
6745d5df06aSAlex Zinenko   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
6755d5df06aSAlex Zinenko 
matchAndRewrite__anon71cd160d0711::TestNestedOpCreationUndoRewrite6765d5df06aSAlex Zinenko   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
6775d5df06aSAlex Zinenko                                 PatternRewriter &rewriter) const final {
6785d5df06aSAlex Zinenko     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
6795d5df06aSAlex Zinenko     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
6805d5df06aSAlex Zinenko     return success();
6815d5df06aSAlex Zinenko   };
6825d5df06aSAlex Zinenko };
683a360a978SMehdi Amini 
684a360a978SMehdi Amini // This pattern matches `test.blackhole` and delete this op and its producer.
685a360a978SMehdi Amini struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
686a360a978SMehdi Amini   using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
687a360a978SMehdi Amini 
matchAndRewrite__anon71cd160d0711::TestReplaceEraseOp688a360a978SMehdi Amini   LogicalResult matchAndRewrite(BlackHoleOp op,
689a360a978SMehdi Amini                                 PatternRewriter &rewriter) const final {
690a360a978SMehdi Amini     Operation *producer = op.getOperand().getDefiningOp();
691a360a978SMehdi Amini     // Always erase the user before the producer, the framework should handle
692a360a978SMehdi Amini     // this correctly.
693a360a978SMehdi Amini     rewriter.eraseOp(op);
694a360a978SMehdi Amini     rewriter.eraseOp(producer);
695a360a978SMehdi Amini     return success();
696a360a978SMehdi Amini   };
697a360a978SMehdi Amini };
698ec03bbe8SVladislav Vinogradov 
699ec03bbe8SVladislav Vinogradov // This pattern replaces explicitly illegal op with explicitly legal op,
700ec03bbe8SVladislav Vinogradov // but in addition creates unregistered operation.
701ec03bbe8SVladislav Vinogradov struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
702ec03bbe8SVladislav Vinogradov   using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
703ec03bbe8SVladislav Vinogradov 
matchAndRewrite__anon71cd160d0711::TestCreateUnregisteredOp704ec03bbe8SVladislav Vinogradov   LogicalResult matchAndRewrite(ILLegalOpG op,
705ec03bbe8SVladislav Vinogradov                                 PatternRewriter &rewriter) const final {
706ec03bbe8SVladislav Vinogradov     IntegerAttr attr = rewriter.getI32IntegerAttr(0);
7078e123ca6SRiver Riddle     Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
708ec03bbe8SVladislav Vinogradov     rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
709ec03bbe8SVladislav Vinogradov     return success();
710ec03bbe8SVladislav Vinogradov   };
711ec03bbe8SVladislav Vinogradov };
712fec6c5acSUday Bondhugula } // namespace
713fec6c5acSUday Bondhugula 
714fec6c5acSUday Bondhugula namespace {
715fec6c5acSUday Bondhugula struct TestTypeConverter : public TypeConverter {
716fec6c5acSUday Bondhugula   using TypeConverter::TypeConverter;
TestTypeConverter__anon71cd160d0b11::TestTypeConverter7175c5dafc5SAlex Zinenko   TestTypeConverter() {
7185c5dafc5SAlex Zinenko     addConversion(convertType);
7194589dd92SRiver Riddle     addArgumentMaterialization(materializeCast);
7204589dd92SRiver Riddle     addSourceMaterialization(materializeCast);
7215c5dafc5SAlex Zinenko   }
722fec6c5acSUday Bondhugula 
convertType__anon71cd160d0b11::TestTypeConverter723fec6c5acSUday Bondhugula   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
724fec6c5acSUday Bondhugula     // Drop I16 types.
725fec6c5acSUday Bondhugula     if (t.isSignlessInteger(16))
726fec6c5acSUday Bondhugula       return success();
727fec6c5acSUday Bondhugula 
728fec6c5acSUday Bondhugula     // Convert I64 to F64.
729fec6c5acSUday Bondhugula     if (t.isSignlessInteger(64)) {
730fec6c5acSUday Bondhugula       results.push_back(FloatType::getF64(t.getContext()));
731fec6c5acSUday Bondhugula       return success();
732fec6c5acSUday Bondhugula     }
733fec6c5acSUday Bondhugula 
7345c5dafc5SAlex Zinenko     // Convert I42 to I43.
7355c5dafc5SAlex Zinenko     if (t.isInteger(42)) {
7361b97cdf8SRiver Riddle       results.push_back(IntegerType::get(t.getContext(), 43));
7375c5dafc5SAlex Zinenko       return success();
7385c5dafc5SAlex Zinenko     }
7395c5dafc5SAlex Zinenko 
740fec6c5acSUday Bondhugula     // Split F32 into F16,F16.
741fec6c5acSUday Bondhugula     if (t.isF32()) {
742fec6c5acSUday Bondhugula       results.assign(2, FloatType::getF16(t.getContext()));
743fec6c5acSUday Bondhugula       return success();
744fec6c5acSUday Bondhugula     }
745fec6c5acSUday Bondhugula 
746fec6c5acSUday Bondhugula     // Otherwise, convert the type directly.
747fec6c5acSUday Bondhugula     results.push_back(t);
748fec6c5acSUday Bondhugula     return success();
749fec6c5acSUday Bondhugula   }
750fec6c5acSUday Bondhugula 
7515c5dafc5SAlex Zinenko   /// Hook for materializing a conversion. This is necessary because we generate
7525c5dafc5SAlex Zinenko   /// 1->N type mappings.
materializeCast__anon71cd160d0b11::TestTypeConverter7534589dd92SRiver Riddle   static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
7544589dd92SRiver Riddle                                          ValueRange inputs, Location loc) {
7554589dd92SRiver Riddle     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
7565c5dafc5SAlex Zinenko   }
757fec6c5acSUday Bondhugula };
758fec6c5acSUday Bondhugula 
759fec6c5acSUday Bondhugula struct TestLegalizePatternDriver
76080aca1eaSRiver Riddle     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d0b11::TestLegalizePatternDriver7615e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
7625e50dd04SRiver Riddle 
763b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-legalize-patterns"; }
getDescription__anon71cd160d0b11::TestLegalizePatternDriver764b5e22e6dSMehdi Amini   StringRef getDescription() const final {
765b5e22e6dSMehdi Amini     return "Run test dialect legalization patterns";
766b5e22e6dSMehdi Amini   }
767fec6c5acSUday Bondhugula   /// The mode of conversion to use with the driver.
768fec6c5acSUday Bondhugula   enum class ConversionMode { Analysis, Full, Partial };
769fec6c5acSUday Bondhugula 
TestLegalizePatternDriver__anon71cd160d0b11::TestLegalizePatternDriver770fec6c5acSUday Bondhugula   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
771fec6c5acSUday Bondhugula 
getDependentDialects__anon71cd160d0b11::TestLegalizePatternDriver772ec03bbe8SVladislav Vinogradov   void getDependentDialects(DialectRegistry &registry) const override {
77323aa5a74SRiver Riddle     registry.insert<func::FuncDialect>();
774ec03bbe8SVladislav Vinogradov   }
775ec03bbe8SVladislav Vinogradov 
runOnOperation__anon71cd160d0b11::TestLegalizePatternDriver776722f909fSRiver Riddle   void runOnOperation() override {
777fec6c5acSUday Bondhugula     TestTypeConverter converter;
778dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
7791d909c9aSChris Lattner     populateWithGenerated(patterns);
780dc4e913bSChris Lattner     patterns
781dc4e913bSChris Lattner         .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
782dc4e913bSChris Lattner              TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
783dc4e913bSChris Lattner              TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
784df48026bSAlex Zinenko              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
785fec6c5acSUday Bondhugula              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
7865d5df06aSAlex Zinenko              TestNonRootReplacement, TestBoundedRecursiveRewrite,
787ec03bbe8SVladislav Vinogradov              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
788ec03bbe8SVladislav Vinogradov              TestCreateUnregisteredOp>(&getContext());
789dc4e913bSChris Lattner     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
79058ceae95SRiver Riddle     mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
79158ceae95SRiver Riddle         patterns, converter);
7923a506b31SChris Lattner     mlir::populateCallOpTypeConversionPattern(patterns, converter);
793fec6c5acSUday Bondhugula 
794fec6c5acSUday Bondhugula     // Define the conversion target used for the test.
795fec6c5acSUday Bondhugula     ConversionTarget target(getContext());
796973ddb7dSMehdi Amini     target.addLegalOp<ModuleOp>();
797ec03bbe8SVladislav Vinogradov     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
798f27f1e8cSAlex Zinenko                       TerminatorOp>();
799fec6c5acSUday Bondhugula     target
800fec6c5acSUday Bondhugula         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
801fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
802fec6c5acSUday Bondhugula       // Don't allow F32 operands.
803fec6c5acSUday Bondhugula       return llvm::none_of(op.getOperandTypes(),
804fec6c5acSUday Bondhugula                            [](Type type) { return type.isF32(); });
805fec6c5acSUday Bondhugula     });
80658ceae95SRiver Riddle     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
8074a3460a7SRiver Riddle       return converter.isSignatureLegal(op.getFunctionType()) &&
8088d67d187SRiver Riddle              converter.isLegal(&op.getBody());
8098d67d187SRiver Riddle     });
81023aa5a74SRiver Riddle     target.addDynamicallyLegalOp<func::CallOp>(
81123aa5a74SRiver Riddle         [&](func::CallOp op) { return converter.isLegal(op); });
812fec6c5acSUday Bondhugula 
813a54f4eaeSMogball     // TestCreateUnregisteredOp creates `arith.constant` operation,
814ec03bbe8SVladislav Vinogradov     // which was not added to target intentionally to test
815ec03bbe8SVladislav Vinogradov     // correct error code from conversion driver.
816ec03bbe8SVladislav Vinogradov     target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
817ec03bbe8SVladislav Vinogradov 
818fec6c5acSUday Bondhugula     // Expect the type_producer/type_consumer operations to only operate on f64.
819fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestTypeProducerOp>(
820fec6c5acSUday Bondhugula         [](TestTypeProducerOp op) { return op.getType().isF64(); });
821fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
822fec6c5acSUday Bondhugula       return op.getOperand().getType().isF64();
823fec6c5acSUday Bondhugula     });
824fec6c5acSUday Bondhugula 
825fec6c5acSUday Bondhugula     // Check support for marking certain operations as recursively legal.
82658ceae95SRiver Riddle     target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
827fec6c5acSUday Bondhugula       return static_cast<bool>(
828fec6c5acSUday Bondhugula           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
829fec6c5acSUday Bondhugula     });
830fec6c5acSUday Bondhugula 
831bd1ccfe6SRiver Riddle     // Mark the bound recursion operation as dynamically legal.
832bd1ccfe6SRiver Riddle     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
8336a994233SJacques Pienaar         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
834bd1ccfe6SRiver Riddle 
835fec6c5acSUday Bondhugula     // Handle a partial conversion.
836fec6c5acSUday Bondhugula     if (mode == ConversionMode::Partial) {
8378de482eaSLucy Fox       DenseSet<Operation *> unlegalizedOps;
838ec03bbe8SVladislav Vinogradov       if (failed(applyPartialConversion(
839ec03bbe8SVladislav Vinogradov               getOperation(), target, std::move(patterns), &unlegalizedOps))) {
840ec03bbe8SVladislav Vinogradov         getOperation()->emitRemark() << "applyPartialConversion failed";
841ec03bbe8SVladislav Vinogradov       }
8428de482eaSLucy Fox       // Emit remarks for each legalizable operation.
8438de482eaSLucy Fox       for (auto *op : unlegalizedOps)
8448de482eaSLucy Fox         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
845fec6c5acSUday Bondhugula       return;
846fec6c5acSUday Bondhugula     }
847fec6c5acSUday Bondhugula 
848fec6c5acSUday Bondhugula     // Handle a full conversion.
849fec6c5acSUday Bondhugula     if (mode == ConversionMode::Full) {
850fec6c5acSUday Bondhugula       // Check support for marking unknown operations as dynamically legal.
851fec6c5acSUday Bondhugula       target.markUnknownOpDynamicallyLegal([](Operation *op) {
852fec6c5acSUday Bondhugula         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
853fec6c5acSUday Bondhugula       });
854fec6c5acSUday Bondhugula 
855ec03bbe8SVladislav Vinogradov       if (failed(applyFullConversion(getOperation(), target,
856ec03bbe8SVladislav Vinogradov                                      std::move(patterns)))) {
857ec03bbe8SVladislav Vinogradov         getOperation()->emitRemark() << "applyFullConversion failed";
858ec03bbe8SVladislav Vinogradov       }
859fec6c5acSUday Bondhugula       return;
860fec6c5acSUday Bondhugula     }
861fec6c5acSUday Bondhugula 
862fec6c5acSUday Bondhugula     // Otherwise, handle an analysis conversion.
863fec6c5acSUday Bondhugula     assert(mode == ConversionMode::Analysis);
864fec6c5acSUday Bondhugula 
865fec6c5acSUday Bondhugula     // Analyze the convertible operations.
866fec6c5acSUday Bondhugula     DenseSet<Operation *> legalizedOps;
8673fffffa8SRiver Riddle     if (failed(applyAnalysisConversion(getOperation(), target,
8683fffffa8SRiver Riddle                                        std::move(patterns), legalizedOps)))
869fec6c5acSUday Bondhugula       return signalPassFailure();
870fec6c5acSUday Bondhugula 
871fec6c5acSUday Bondhugula     // Emit remarks for each legalizable operation.
872fec6c5acSUday Bondhugula     for (auto *op : legalizedOps)
873fec6c5acSUday Bondhugula       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
874fec6c5acSUday Bondhugula   }
875fec6c5acSUday Bondhugula 
876fec6c5acSUday Bondhugula   /// The mode of conversion to use.
877fec6c5acSUday Bondhugula   ConversionMode mode;
878fec6c5acSUday Bondhugula };
879be0a7e9fSMehdi Amini } // namespace
880fec6c5acSUday Bondhugula 
881fec6c5acSUday Bondhugula static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
882fec6c5acSUday Bondhugula     legalizerConversionMode(
883fec6c5acSUday Bondhugula         "test-legalize-mode",
884fec6c5acSUday Bondhugula         llvm::cl::desc("The legalization mode to use with the test driver"),
885fec6c5acSUday Bondhugula         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
886fec6c5acSUday Bondhugula         llvm::cl::values(
887fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
888fec6c5acSUday Bondhugula                        "analysis", "Perform an analysis conversion"),
889fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
890fec6c5acSUday Bondhugula                        "Perform a full conversion"),
891fec6c5acSUday Bondhugula             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
892fec6c5acSUday Bondhugula                        "partial", "Perform a partial conversion")));
893fec6c5acSUday Bondhugula 
894fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
895fec6c5acSUday Bondhugula // ConversionPatternRewriter::getRemappedValue testing. This method is used
8965aacce3dSKazuaki Ishizaki // to get the remapped value of an original value that was replaced using
897fec6c5acSUday Bondhugula // ConversionPatternRewriter.
898fec6c5acSUday Bondhugula namespace {
899015192c6SRiver Riddle struct TestRemapValueTypeConverter : public TypeConverter {
900015192c6SRiver Riddle   using TypeConverter::TypeConverter;
901015192c6SRiver Riddle 
TestRemapValueTypeConverter__anon71cd160d1611::TestRemapValueTypeConverter902015192c6SRiver Riddle   TestRemapValueTypeConverter() {
903015192c6SRiver Riddle     addConversion(
904015192c6SRiver Riddle         [](Float32Type type) { return Float64Type::get(type.getContext()); });
905015192c6SRiver Riddle     addConversion([](Type type) { return type; });
906015192c6SRiver Riddle   }
907015192c6SRiver Riddle };
908015192c6SRiver Riddle 
909fec6c5acSUday Bondhugula /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
910fec6c5acSUday Bondhugula /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
911fec6c5acSUday Bondhugula /// operand twice.
912fec6c5acSUday Bondhugula ///
913fec6c5acSUday Bondhugula /// Example:
914fec6c5acSUday Bondhugula ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
915fec6c5acSUday Bondhugula /// is replaced with:
916fec6c5acSUday Bondhugula ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
917fec6c5acSUday Bondhugula struct OneVResOneVOperandOp1Converter
918fec6c5acSUday Bondhugula     : public OpConversionPattern<OneVResOneVOperandOp1> {
919fec6c5acSUday Bondhugula   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
920fec6c5acSUday Bondhugula 
921fec6c5acSUday Bondhugula   LogicalResult
matchAndRewrite__anon71cd160d1611::OneVResOneVOperandOp1Converter922ef976337SRiver Riddle   matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
923fec6c5acSUday Bondhugula                   ConversionPatternRewriter &rewriter) const override {
924fec6c5acSUday Bondhugula     auto origOps = op.getOperands();
925fec6c5acSUday Bondhugula     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
926fec6c5acSUday Bondhugula            "One operand expected");
927fec6c5acSUday Bondhugula     Value origOp = *origOps.begin();
928fec6c5acSUday Bondhugula     SmallVector<Value, 2> remappedOperands;
929fec6c5acSUday Bondhugula     // Replicate the remapped original operand twice. Note that we don't used
930fec6c5acSUday Bondhugula     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
931fec6c5acSUday Bondhugula     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
932fec6c5acSUday Bondhugula     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
933fec6c5acSUday Bondhugula 
934fec6c5acSUday Bondhugula     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
935fec6c5acSUday Bondhugula                                                        remappedOperands);
936fec6c5acSUday Bondhugula     return success();
937fec6c5acSUday Bondhugula   }
938fec6c5acSUday Bondhugula };
939fec6c5acSUday Bondhugula 
940015192c6SRiver Riddle /// A rewriter pattern that tests that blocks can be merged.
941015192c6SRiver Riddle struct TestRemapValueInRegion
942015192c6SRiver Riddle     : public OpConversionPattern<TestRemappedValueRegionOp> {
943015192c6SRiver Riddle   using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
944015192c6SRiver Riddle 
945015192c6SRiver Riddle   LogicalResult
matchAndRewrite__anon71cd160d1611::TestRemapValueInRegion946015192c6SRiver Riddle   matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
947015192c6SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
948015192c6SRiver Riddle     Block &block = op.getBody().front();
949015192c6SRiver Riddle     Operation *terminator = block.getTerminator();
950015192c6SRiver Riddle 
951015192c6SRiver Riddle     // Merge the block into the parent region.
952015192c6SRiver Riddle     Block *parentBlock = op->getBlock();
953015192c6SRiver Riddle     Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
954015192c6SRiver Riddle     rewriter.mergeBlocks(&block, parentBlock, ValueRange());
955015192c6SRiver Riddle     rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
956015192c6SRiver Riddle 
957015192c6SRiver Riddle     // Replace the results of this operation with the remapped terminator
958015192c6SRiver Riddle     // values.
959015192c6SRiver Riddle     SmallVector<Value> terminatorOperands;
960015192c6SRiver Riddle     if (failed(rewriter.getRemappedValues(terminator->getOperands(),
961015192c6SRiver Riddle                                           terminatorOperands)))
962015192c6SRiver Riddle       return failure();
963015192c6SRiver Riddle 
964015192c6SRiver Riddle     rewriter.eraseOp(terminator);
965015192c6SRiver Riddle     rewriter.replaceOp(op, terminatorOperands);
966015192c6SRiver Riddle     return success();
967015192c6SRiver Riddle   }
968015192c6SRiver Riddle };
969015192c6SRiver Riddle 
97080aca1eaSRiver Riddle struct TestRemappedValue
97158ceae95SRiver Riddle     : public mlir::PassWrapper<TestRemappedValue, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d1611::TestRemappedValue9725e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
9735e50dd04SRiver Riddle 
974b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-remapped-value"; }
getDescription__anon71cd160d1611::TestRemappedValue975b5e22e6dSMehdi Amini   StringRef getDescription() const final {
976b5e22e6dSMehdi Amini     return "Test public remapped value mechanism in ConversionPatternRewriter";
977b5e22e6dSMehdi Amini   }
runOnOperation__anon71cd160d1611::TestRemappedValue97841574554SRiver Riddle   void runOnOperation() override {
979015192c6SRiver Riddle     TestRemapValueTypeConverter typeConverter;
980015192c6SRiver Riddle 
981dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
982dc4e913bSChris Lattner     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
983015192c6SRiver Riddle     patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
984015192c6SRiver Riddle         &getContext());
985015192c6SRiver Riddle     patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
986fec6c5acSUday Bondhugula 
987fec6c5acSUday Bondhugula     mlir::ConversionTarget target(getContext());
98858ceae95SRiver Riddle     target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
989015192c6SRiver Riddle 
990015192c6SRiver Riddle     // Expect the type_producer/type_consumer operations to only operate on f64.
991015192c6SRiver Riddle     target.addDynamicallyLegalOp<TestTypeProducerOp>(
992015192c6SRiver Riddle         [](TestTypeProducerOp op) { return op.getType().isF64(); });
993015192c6SRiver Riddle     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
994015192c6SRiver Riddle       return op.getOperand().getType().isF64();
995015192c6SRiver Riddle     });
996015192c6SRiver Riddle 
997fec6c5acSUday Bondhugula     // We make OneVResOneVOperandOp1 legal only when it has more that one
998fec6c5acSUday Bondhugula     // operand. This will trigger the conversion that will replace one-operand
999fec6c5acSUday Bondhugula     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1000fec6c5acSUday Bondhugula     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1001015192c6SRiver Riddle         [](Operation *op) { return op->getNumOperands() > 1; });
1002fec6c5acSUday Bondhugula 
100341574554SRiver Riddle     if (failed(mlir::applyFullConversion(getOperation(), target,
10043fffffa8SRiver Riddle                                          std::move(patterns)))) {
1005fec6c5acSUday Bondhugula       signalPassFailure();
1006fec6c5acSUday Bondhugula     }
1007fec6c5acSUday Bondhugula   }
1008fec6c5acSUday Bondhugula };
1009be0a7e9fSMehdi Amini } // namespace
1010fec6c5acSUday Bondhugula 
101180d7ac3bSRiver Riddle //===----------------------------------------------------------------------===//
101280d7ac3bSRiver Riddle // Test patterns without a specific root operation kind
101380d7ac3bSRiver Riddle //===----------------------------------------------------------------------===//
101480d7ac3bSRiver Riddle 
101580d7ac3bSRiver Riddle namespace {
101680d7ac3bSRiver Riddle /// This pattern matches and removes any operation in the test dialect.
101780d7ac3bSRiver Riddle struct RemoveTestDialectOps : public RewritePattern {
RemoveTestDialectOps__anon71cd160d1c11::RemoveTestDialectOps101876f3c2f3SRiver Riddle   RemoveTestDialectOps(MLIRContext *context)
101976f3c2f3SRiver Riddle       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
102080d7ac3bSRiver Riddle 
matchAndRewrite__anon71cd160d1c11::RemoveTestDialectOps102180d7ac3bSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
102280d7ac3bSRiver Riddle                                 PatternRewriter &rewriter) const override {
102380d7ac3bSRiver Riddle     if (!isa<TestDialect>(op->getDialect()))
102480d7ac3bSRiver Riddle       return failure();
102580d7ac3bSRiver Riddle     rewriter.eraseOp(op);
102680d7ac3bSRiver Riddle     return success();
102780d7ac3bSRiver Riddle   }
102880d7ac3bSRiver Riddle };
102980d7ac3bSRiver Riddle 
103080d7ac3bSRiver Riddle struct TestUnknownRootOpDriver
103158ceae95SRiver Riddle     : public mlir::PassWrapper<TestUnknownRootOpDriver,
103258ceae95SRiver Riddle                                OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d1c11::TestUnknownRootOpDriver10335e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
10345e50dd04SRiver Riddle 
1035b5e22e6dSMehdi Amini   StringRef getArgument() const final {
1036b5e22e6dSMehdi Amini     return "test-legalize-unknown-root-patterns";
1037b5e22e6dSMehdi Amini   }
getDescription__anon71cd160d1c11::TestUnknownRootOpDriver1038b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1039b5e22e6dSMehdi Amini     return "Test public remapped value mechanism in ConversionPatternRewriter";
1040b5e22e6dSMehdi Amini   }
runOnOperation__anon71cd160d1c11::TestUnknownRootOpDriver104141574554SRiver Riddle   void runOnOperation() override {
1042dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(&getContext());
104376f3c2f3SRiver Riddle     patterns.add<RemoveTestDialectOps>(&getContext());
104480d7ac3bSRiver Riddle 
104580d7ac3bSRiver Riddle     mlir::ConversionTarget target(getContext());
104680d7ac3bSRiver Riddle     target.addIllegalDialect<TestDialect>();
104741574554SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
104841574554SRiver Riddle                                       std::move(patterns))))
104980d7ac3bSRiver Riddle       signalPassFailure();
105080d7ac3bSRiver Riddle   }
105180d7ac3bSRiver Riddle };
1052be0a7e9fSMehdi Amini } // namespace
105380d7ac3bSRiver Riddle 
10544589dd92SRiver Riddle //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
105688bc24a7SMathieu Fehr //===----------------------------------------------------------------------===//
105788bc24a7SMathieu Fehr 
105888bc24a7SMathieu Fehr namespace {
105988bc24a7SMathieu Fehr /// This pattern matches dynamic operations 'test.one_operand_two_results' and
106088bc24a7SMathieu Fehr /// replace them with dynamic operations 'test.generic_dynamic_op'.
106188bc24a7SMathieu Fehr struct RewriteDynamicOp : public RewritePattern {
RewriteDynamicOp__anon71cd160d1d11::RewriteDynamicOp106288bc24a7SMathieu Fehr   RewriteDynamicOp(MLIRContext *context)
106388bc24a7SMathieu Fehr       : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
106488bc24a7SMathieu Fehr                        context) {}
106588bc24a7SMathieu Fehr 
matchAndRewrite__anon71cd160d1d11::RewriteDynamicOp106688bc24a7SMathieu Fehr   LogicalResult matchAndRewrite(Operation *op,
106788bc24a7SMathieu Fehr                                 PatternRewriter &rewriter) const override {
106888bc24a7SMathieu Fehr     assert(op->getName().getStringRef() ==
106988bc24a7SMathieu Fehr                "test.dynamic_one_operand_two_results" &&
107088bc24a7SMathieu Fehr            "rewrite pattern should only match operations with the right name");
107188bc24a7SMathieu Fehr 
107288bc24a7SMathieu Fehr     OperationState state(op->getLoc(), "test.dynamic_generic",
107388bc24a7SMathieu Fehr                          op->getOperands(), op->getResultTypes(),
107488bc24a7SMathieu Fehr                          op->getAttrs());
107588bc24a7SMathieu Fehr     auto *newOp = rewriter.create(state);
107688bc24a7SMathieu Fehr     rewriter.replaceOp(op, newOp->getResults());
107788bc24a7SMathieu Fehr     return success();
107888bc24a7SMathieu Fehr   }
107988bc24a7SMathieu Fehr };
108088bc24a7SMathieu Fehr 
108188bc24a7SMathieu Fehr struct TestRewriteDynamicOpDriver
108288bc24a7SMathieu Fehr     : public PassWrapper<TestRewriteDynamicOpDriver,
108388bc24a7SMathieu Fehr                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d1d11::TestRewriteDynamicOpDriver108488bc24a7SMathieu Fehr   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
108588bc24a7SMathieu Fehr 
108688bc24a7SMathieu Fehr   void getDependentDialects(DialectRegistry &registry) const override {
108788bc24a7SMathieu Fehr     registry.insert<TestDialect>();
108888bc24a7SMathieu Fehr   }
getArgument__anon71cd160d1d11::TestRewriteDynamicOpDriver108988bc24a7SMathieu Fehr   StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
getDescription__anon71cd160d1d11::TestRewriteDynamicOpDriver109088bc24a7SMathieu Fehr   StringRef getDescription() const final {
109188bc24a7SMathieu Fehr     return "Test rewritting on dynamic operations";
109288bc24a7SMathieu Fehr   }
runOnOperation__anon71cd160d1d11::TestRewriteDynamicOpDriver109388bc24a7SMathieu Fehr   void runOnOperation() override {
109488bc24a7SMathieu Fehr     RewritePatternSet patterns(&getContext());
109588bc24a7SMathieu Fehr     patterns.add<RewriteDynamicOp>(&getContext());
109688bc24a7SMathieu Fehr 
109788bc24a7SMathieu Fehr     ConversionTarget target(getContext());
109888bc24a7SMathieu Fehr     target.addIllegalOp(
109988bc24a7SMathieu Fehr         OperationName("test.dynamic_one_operand_two_results", &getContext()));
110088bc24a7SMathieu Fehr     target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
110188bc24a7SMathieu Fehr     if (failed(applyPartialConversion(getOperation(), target,
110288bc24a7SMathieu Fehr                                       std::move(patterns))))
110388bc24a7SMathieu Fehr       signalPassFailure();
110488bc24a7SMathieu Fehr   }
110588bc24a7SMathieu Fehr };
110688bc24a7SMathieu Fehr } // end anonymous namespace
110788bc24a7SMathieu Fehr 
110888bc24a7SMathieu Fehr //===----------------------------------------------------------------------===//
11094589dd92SRiver Riddle // Test type conversions
11104589dd92SRiver Riddle //===----------------------------------------------------------------------===//
11114589dd92SRiver Riddle 
11124589dd92SRiver Riddle namespace {
11134589dd92SRiver Riddle struct TestTypeConversionProducer
11144589dd92SRiver Riddle     : public OpConversionPattern<TestTypeProducerOp> {
11154589dd92SRiver Riddle   using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
11164589dd92SRiver Riddle   LogicalResult
matchAndRewrite__anon71cd160d1e11::TestTypeConversionProducer1117ef976337SRiver Riddle   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
11184589dd92SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
11194589dd92SRiver Riddle     Type resultType = op.getType();
11209c5982efSAlex Zinenko     Type convertedType = getTypeConverter()
11219c5982efSAlex Zinenko                              ? getTypeConverter()->convertType(resultType)
11229c5982efSAlex Zinenko                              : resultType;
11234589dd92SRiver Riddle     if (resultType.isa<FloatType>())
11244589dd92SRiver Riddle       resultType = rewriter.getF64Type();
11254589dd92SRiver Riddle     else if (resultType.isInteger(16))
11264589dd92SRiver Riddle       resultType = rewriter.getIntegerType(64);
11279c5982efSAlex Zinenko     else if (resultType.isa<test::TestRecursiveType>() &&
11289c5982efSAlex Zinenko              convertedType != resultType)
11299c5982efSAlex Zinenko       resultType = convertedType;
11304589dd92SRiver Riddle     else
11314589dd92SRiver Riddle       return failure();
11324589dd92SRiver Riddle 
11334589dd92SRiver Riddle     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
11344589dd92SRiver Riddle     return success();
11354589dd92SRiver Riddle   }
11364589dd92SRiver Riddle };
11374589dd92SRiver Riddle 
11380409eb28SAlex Zinenko /// Call signature conversion and then fail the rewrite to trigger the undo
11390409eb28SAlex Zinenko /// mechanism.
11400409eb28SAlex Zinenko struct TestSignatureConversionUndo
11410409eb28SAlex Zinenko     : public OpConversionPattern<TestSignatureConversionUndoOp> {
11420409eb28SAlex Zinenko   using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
11430409eb28SAlex Zinenko 
11440409eb28SAlex Zinenko   LogicalResult
matchAndRewrite__anon71cd160d1e11::TestSignatureConversionUndo1145ef976337SRiver Riddle   matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
11460409eb28SAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
11470409eb28SAlex Zinenko     (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
11480409eb28SAlex Zinenko     return failure();
11490409eb28SAlex Zinenko   }
11500409eb28SAlex Zinenko };
11510409eb28SAlex Zinenko 
11524070f305SRiver Riddle /// Call signature conversion without providing a type converter to handle
11534070f305SRiver Riddle /// materializations.
11544070f305SRiver Riddle struct TestTestSignatureConversionNoConverter
11554070f305SRiver Riddle     : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
TestTestSignatureConversionNoConverter__anon71cd160d1e11::TestTestSignatureConversionNoConverter11564070f305SRiver Riddle   TestTestSignatureConversionNoConverter(TypeConverter &converter,
11574070f305SRiver Riddle                                          MLIRContext *context)
11584070f305SRiver Riddle       : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
11594070f305SRiver Riddle         converter(converter) {}
11604070f305SRiver Riddle 
11614070f305SRiver Riddle   LogicalResult
matchAndRewrite__anon71cd160d1e11::TestTestSignatureConversionNoConverter11624070f305SRiver Riddle   matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
11634070f305SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
11644070f305SRiver Riddle     Region &region = op->getRegion(0);
11654070f305SRiver Riddle     Block *entry = &region.front();
11664070f305SRiver Riddle 
11674070f305SRiver Riddle     // Convert the original entry arguments.
11684070f305SRiver Riddle     TypeConverter::SignatureConversion result(entry->getNumArguments());
11694070f305SRiver Riddle     if (failed(
11704070f305SRiver Riddle             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
11714070f305SRiver Riddle       return failure();
11724070f305SRiver Riddle     rewriter.updateRootInPlace(
11734070f305SRiver Riddle         op, [&] { rewriter.applySignatureConversion(&region, result); });
11744070f305SRiver Riddle     return success();
11754070f305SRiver Riddle   }
11764070f305SRiver Riddle 
11774070f305SRiver Riddle   TypeConverter &converter;
11784070f305SRiver Riddle };
11794070f305SRiver Riddle 
11800409eb28SAlex Zinenko /// Just forward the operands to the root op. This is essentially a no-op
11810409eb28SAlex Zinenko /// pattern that is used to trigger target materialization.
11820409eb28SAlex Zinenko struct TestTypeConsumerForward
11830409eb28SAlex Zinenko     : public OpConversionPattern<TestTypeConsumerOp> {
11840409eb28SAlex Zinenko   using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
11850409eb28SAlex Zinenko 
11860409eb28SAlex Zinenko   LogicalResult
matchAndRewrite__anon71cd160d1e11::TestTypeConsumerForward1187ef976337SRiver Riddle   matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
11880409eb28SAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
1189ef976337SRiver Riddle     rewriter.updateRootInPlace(op,
1190ef976337SRiver Riddle                                [&] { op->setOperands(adaptor.getOperands()); });
11910409eb28SAlex Zinenko     return success();
11920409eb28SAlex Zinenko   }
11930409eb28SAlex Zinenko };
11940409eb28SAlex Zinenko 
11955b91060dSAlex Zinenko struct TestTypeConversionAnotherProducer
11965b91060dSAlex Zinenko     : public OpRewritePattern<TestAnotherTypeProducerOp> {
11975b91060dSAlex Zinenko   using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
11985b91060dSAlex Zinenko 
matchAndRewrite__anon71cd160d1e11::TestTypeConversionAnotherProducer11995b91060dSAlex Zinenko   LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
12005b91060dSAlex Zinenko                                 PatternRewriter &rewriter) const final {
12015b91060dSAlex Zinenko     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
12025b91060dSAlex Zinenko     return success();
12035b91060dSAlex Zinenko   }
12045b91060dSAlex Zinenko };
12055b91060dSAlex Zinenko 
12064589dd92SRiver Riddle struct TestTypeConversionDriver
12074589dd92SRiver Riddle     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d1e11::TestTypeConversionDriver12085e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
12095e50dd04SRiver Riddle 
1210f9dc2b70SMehdi Amini   void getDependentDialects(DialectRegistry &registry) const override {
1211f9dc2b70SMehdi Amini     registry.insert<TestDialect>();
1212f9dc2b70SMehdi Amini   }
getArgument__anon71cd160d1e11::TestTypeConversionDriver1213b5e22e6dSMehdi Amini   StringRef getArgument() const final {
1214b5e22e6dSMehdi Amini     return "test-legalize-type-conversion";
1215b5e22e6dSMehdi Amini   }
getDescription__anon71cd160d1e11::TestTypeConversionDriver1216b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1217b5e22e6dSMehdi Amini     return "Test various type conversion functionalities in DialectConversion";
1218b5e22e6dSMehdi Amini   }
1219f9dc2b70SMehdi Amini 
runOnOperation__anon71cd160d1e11::TestTypeConversionDriver12204589dd92SRiver Riddle   void runOnOperation() override {
12214589dd92SRiver Riddle     // Initialize the type converter.
12224589dd92SRiver Riddle     TypeConverter converter;
12234589dd92SRiver Riddle 
12244589dd92SRiver Riddle     /// Add the legal set of type conversions.
12254589dd92SRiver Riddle     converter.addConversion([](Type type) -> Type {
12264589dd92SRiver Riddle       // Treat F64 as legal.
12274589dd92SRiver Riddle       if (type.isF64())
12284589dd92SRiver Riddle         return type;
12294589dd92SRiver Riddle       // Allow converting BF16/F16/F32 to F64.
12304589dd92SRiver Riddle       if (type.isBF16() || type.isF16() || type.isF32())
12314589dd92SRiver Riddle         return FloatType::getF64(type.getContext());
12324589dd92SRiver Riddle       // Otherwise, the type is illegal.
12334589dd92SRiver Riddle       return nullptr;
12344589dd92SRiver Riddle     });
12354589dd92SRiver Riddle     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
12364589dd92SRiver Riddle       // Drop all integer types.
12374589dd92SRiver Riddle       return success();
12384589dd92SRiver Riddle     });
12399c5982efSAlex Zinenko     converter.addConversion(
12409c5982efSAlex Zinenko         // Convert a recursive self-referring type into a non-self-referring
12419c5982efSAlex Zinenko         // type named "outer_converted_type" that contains a SimpleAType.
12429c5982efSAlex Zinenko         [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
12439c5982efSAlex Zinenko             ArrayRef<Type> callStack) -> Optional<LogicalResult> {
12449c5982efSAlex Zinenko           // If the type is already converted, return it to indicate that it is
12459c5982efSAlex Zinenko           // legal.
12469c5982efSAlex Zinenko           if (type.getName() == "outer_converted_type") {
12479c5982efSAlex Zinenko             results.push_back(type);
12489c5982efSAlex Zinenko             return success();
12499c5982efSAlex Zinenko           }
12509c5982efSAlex Zinenko 
12519c5982efSAlex Zinenko           // If the type is on the call stack more than once (it is there at
12529c5982efSAlex Zinenko           // least once because of the _current_ call, which is always the last
12539c5982efSAlex Zinenko           // element on the stack), we've hit the recursive case. Just return
12549c5982efSAlex Zinenko           // SimpleAType here to create a non-recursive type as a result.
12559c5982efSAlex Zinenko           if (llvm::is_contained(callStack.drop_back(), type)) {
12569c5982efSAlex Zinenko             results.push_back(test::SimpleAType::get(type.getContext()));
12579c5982efSAlex Zinenko             return success();
12589c5982efSAlex Zinenko           }
12599c5982efSAlex Zinenko 
12609c5982efSAlex Zinenko           // Convert the body recursively.
12619c5982efSAlex Zinenko           auto result = test::TestRecursiveType::get(type.getContext(),
12629c5982efSAlex Zinenko                                                      "outer_converted_type");
12639c5982efSAlex Zinenko           if (failed(result.setBody(converter.convertType(type.getBody()))))
12649c5982efSAlex Zinenko             return failure();
12659c5982efSAlex Zinenko           results.push_back(result);
12669c5982efSAlex Zinenko           return success();
12679c5982efSAlex Zinenko         });
12684589dd92SRiver Riddle 
12694589dd92SRiver Riddle     /// Add the legal set of type materializations.
12704589dd92SRiver Riddle     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
12714589dd92SRiver Riddle                                           ValueRange inputs,
12724589dd92SRiver Riddle                                           Location loc) -> Value {
12734589dd92SRiver Riddle       // Allow casting from F64 back to F32.
12744589dd92SRiver Riddle       if (!resultType.isF16() && inputs.size() == 1 &&
12754589dd92SRiver Riddle           inputs[0].getType().isF64())
12764589dd92SRiver Riddle         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
12774589dd92SRiver Riddle       // Allow producing an i32 or i64 from nothing.
12784589dd92SRiver Riddle       if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
12794589dd92SRiver Riddle           inputs.empty())
12804589dd92SRiver Riddle         return builder.create<TestTypeProducerOp>(loc, resultType);
12814589dd92SRiver Riddle       // Allow producing an i64 from an integer.
12824589dd92SRiver Riddle       if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
12834589dd92SRiver Riddle           inputs[0].getType().isa<IntegerType>())
12844589dd92SRiver Riddle         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
12854589dd92SRiver Riddle       // Otherwise, fail.
12864589dd92SRiver Riddle       return nullptr;
12874589dd92SRiver Riddle     });
12884589dd92SRiver Riddle 
12894589dd92SRiver Riddle     // Initialize the conversion target.
12904589dd92SRiver Riddle     mlir::ConversionTarget target(getContext());
12914589dd92SRiver Riddle     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
12929c5982efSAlex Zinenko       auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
12939c5982efSAlex Zinenko       return op.getType().isF64() || op.getType().isInteger(64) ||
12949c5982efSAlex Zinenko              (recursiveType &&
12959c5982efSAlex Zinenko               recursiveType.getName() == "outer_converted_type");
12964589dd92SRiver Riddle     });
129758ceae95SRiver Riddle     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
12984a3460a7SRiver Riddle       return converter.isSignatureLegal(op.getFunctionType()) &&
12994589dd92SRiver Riddle              converter.isLegal(&op.getBody());
13004589dd92SRiver Riddle     });
13014589dd92SRiver Riddle     target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
13024589dd92SRiver Riddle       // Allow casts from F64 to F32.
13034589dd92SRiver Riddle       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
13044589dd92SRiver Riddle     });
13054070f305SRiver Riddle     target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
13064070f305SRiver Riddle         [&](TestSignatureConversionNoConverterOp op) {
13074070f305SRiver Riddle           return converter.isLegal(op.getRegion().front().getArgumentTypes());
13084070f305SRiver Riddle         });
13094589dd92SRiver Riddle 
13104589dd92SRiver Riddle     // Initialize the set of rewrite patterns.
1311dc4e913bSChris Lattner     RewritePatternSet patterns(&getContext());
1312dc4e913bSChris Lattner     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
13134070f305SRiver Riddle                  TestSignatureConversionUndo,
13144070f305SRiver Riddle                  TestTestSignatureConversionNoConverter>(converter,
13154070f305SRiver Riddle                                                          &getContext());
1316dc4e913bSChris Lattner     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
131758ceae95SRiver Riddle     mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
131858ceae95SRiver Riddle         patterns, converter);
13194589dd92SRiver Riddle 
13203fffffa8SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
13213fffffa8SRiver Riddle                                       std::move(patterns))))
13224589dd92SRiver Riddle       signalPassFailure();
13234589dd92SRiver Riddle   }
13244589dd92SRiver Riddle };
1325be0a7e9fSMehdi Amini } // namespace
13264589dd92SRiver Riddle 
1327c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1328d4a53f3bSAlex Zinenko // Test Target Materialization With No Uses
1329d4a53f3bSAlex Zinenko //===----------------------------------------------------------------------===//
1330d4a53f3bSAlex Zinenko 
1331d4a53f3bSAlex Zinenko namespace {
1332d4a53f3bSAlex Zinenko struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1333d4a53f3bSAlex Zinenko   using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1334d4a53f3bSAlex Zinenko 
1335d4a53f3bSAlex Zinenko   LogicalResult
matchAndRewrite__anon71cd160d2911::ForwardOperandPattern1336d4a53f3bSAlex Zinenko   matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1337d4a53f3bSAlex Zinenko                   ConversionPatternRewriter &rewriter) const final {
1338d4a53f3bSAlex Zinenko     rewriter.replaceOp(op, adaptor.getOperands());
1339d4a53f3bSAlex Zinenko     return success();
1340d4a53f3bSAlex Zinenko   }
1341d4a53f3bSAlex Zinenko };
1342d4a53f3bSAlex Zinenko 
1343d4a53f3bSAlex Zinenko struct TestTargetMaterializationWithNoUses
1344d4a53f3bSAlex Zinenko     : public PassWrapper<TestTargetMaterializationWithNoUses,
1345d4a53f3bSAlex Zinenko                          OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d2911::TestTargetMaterializationWithNoUses13465e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
13475e50dd04SRiver Riddle       TestTargetMaterializationWithNoUses)
13485e50dd04SRiver Riddle 
1349d4a53f3bSAlex Zinenko   StringRef getArgument() const final {
1350d4a53f3bSAlex Zinenko     return "test-target-materialization-with-no-uses";
1351d4a53f3bSAlex Zinenko   }
getDescription__anon71cd160d2911::TestTargetMaterializationWithNoUses1352d4a53f3bSAlex Zinenko   StringRef getDescription() const final {
1353d4a53f3bSAlex Zinenko     return "Test a special case of target materialization in DialectConversion";
1354d4a53f3bSAlex Zinenko   }
1355d4a53f3bSAlex Zinenko 
runOnOperation__anon71cd160d2911::TestTargetMaterializationWithNoUses1356d4a53f3bSAlex Zinenko   void runOnOperation() override {
1357d4a53f3bSAlex Zinenko     TypeConverter converter;
1358d4a53f3bSAlex Zinenko     converter.addConversion([](Type t) { return t; });
1359d4a53f3bSAlex Zinenko     converter.addConversion([](IntegerType intTy) -> Type {
1360d4a53f3bSAlex Zinenko       if (intTy.getWidth() == 16)
1361d4a53f3bSAlex Zinenko         return IntegerType::get(intTy.getContext(), 64);
1362d4a53f3bSAlex Zinenko       return intTy;
1363d4a53f3bSAlex Zinenko     });
1364d4a53f3bSAlex Zinenko     converter.addTargetMaterialization(
1365d4a53f3bSAlex Zinenko         [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1366d4a53f3bSAlex Zinenko           return builder.create<TestCastOp>(loc, type, inputs).getResult();
1367d4a53f3bSAlex Zinenko         });
1368d4a53f3bSAlex Zinenko 
1369d4a53f3bSAlex Zinenko     ConversionTarget target(getContext());
1370d4a53f3bSAlex Zinenko     target.addIllegalOp<TestTypeChangerOp>();
1371d4a53f3bSAlex Zinenko 
1372d4a53f3bSAlex Zinenko     RewritePatternSet patterns(&getContext());
1373d4a53f3bSAlex Zinenko     patterns.add<ForwardOperandPattern>(converter, &getContext());
1374d4a53f3bSAlex Zinenko 
1375d4a53f3bSAlex Zinenko     if (failed(applyPartialConversion(getOperation(), target,
1376d4a53f3bSAlex Zinenko                                       std::move(patterns))))
1377d4a53f3bSAlex Zinenko       signalPassFailure();
1378d4a53f3bSAlex Zinenko   }
1379d4a53f3bSAlex Zinenko };
1380d4a53f3bSAlex Zinenko } // namespace
1381d4a53f3bSAlex Zinenko 
1382d4a53f3bSAlex Zinenko //===----------------------------------------------------------------------===//
1383c8fb6ee3SRiver Riddle // Test Block Merging
1384c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1385c8fb6ee3SRiver Riddle 
1386e888886cSMaheshRavishankar namespace {
1387e888886cSMaheshRavishankar /// A rewriter pattern that tests that blocks can be merged.
1388e888886cSMaheshRavishankar struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1389e888886cSMaheshRavishankar   using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1390e888886cSMaheshRavishankar 
1391e888886cSMaheshRavishankar   LogicalResult
matchAndRewrite__anon71cd160d2d11::TestMergeBlock1392ef976337SRiver Riddle   matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1393e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
13946a994233SJacques Pienaar     Block &firstBlock = op.getBody().front();
1395e888886cSMaheshRavishankar     Operation *branchOp = firstBlock.getTerminator();
13966a994233SJacques Pienaar     Block *secondBlock = &*(std::next(op.getBody().begin()));
1397e888886cSMaheshRavishankar     auto succOperands = branchOp->getOperands();
1398e888886cSMaheshRavishankar     SmallVector<Value, 2> replacements(succOperands);
1399e888886cSMaheshRavishankar     rewriter.eraseOp(branchOp);
1400e888886cSMaheshRavishankar     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1401e888886cSMaheshRavishankar     rewriter.updateRootInPlace(op, [] {});
1402e888886cSMaheshRavishankar     return success();
1403e888886cSMaheshRavishankar   }
1404e888886cSMaheshRavishankar };
1405e888886cSMaheshRavishankar 
1406e888886cSMaheshRavishankar /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1407e888886cSMaheshRavishankar struct TestUndoBlocksMerge : public ConversionPattern {
TestUndoBlocksMerge__anon71cd160d2d11::TestUndoBlocksMerge1408e888886cSMaheshRavishankar   TestUndoBlocksMerge(MLIRContext *ctx)
1409e888886cSMaheshRavishankar       : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1410e888886cSMaheshRavishankar   LogicalResult
matchAndRewrite__anon71cd160d2d11::TestUndoBlocksMerge1411e888886cSMaheshRavishankar   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1412e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
1413e888886cSMaheshRavishankar     Block &firstBlock = op->getRegion(0).front();
1414e888886cSMaheshRavishankar     Operation *branchOp = firstBlock.getTerminator();
1415e888886cSMaheshRavishankar     Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1416e888886cSMaheshRavishankar     rewriter.setInsertionPointToStart(secondBlock);
1417e888886cSMaheshRavishankar     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1418e888886cSMaheshRavishankar     auto succOperands = branchOp->getOperands();
1419e888886cSMaheshRavishankar     SmallVector<Value, 2> replacements(succOperands);
1420e888886cSMaheshRavishankar     rewriter.eraseOp(branchOp);
1421e888886cSMaheshRavishankar     rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1422e888886cSMaheshRavishankar     rewriter.updateRootInPlace(op, [] {});
1423e888886cSMaheshRavishankar     return success();
1424e888886cSMaheshRavishankar   }
1425e888886cSMaheshRavishankar };
1426e888886cSMaheshRavishankar 
1427e888886cSMaheshRavishankar /// A rewrite mechanism to inline the body of the op into its parent, when both
1428e888886cSMaheshRavishankar /// ops can have a single block.
1429e888886cSMaheshRavishankar struct TestMergeSingleBlockOps
1430e888886cSMaheshRavishankar     : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1431e888886cSMaheshRavishankar   using OpConversionPattern<
1432e888886cSMaheshRavishankar       SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1433e888886cSMaheshRavishankar 
1434e888886cSMaheshRavishankar   LogicalResult
matchAndRewrite__anon71cd160d2d11::TestMergeSingleBlockOps1435ef976337SRiver Riddle   matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1436e888886cSMaheshRavishankar                   ConversionPatternRewriter &rewriter) const final {
1437e888886cSMaheshRavishankar     SingleBlockImplicitTerminatorOp parentOp =
14380bf4a82aSChristian Sigg         op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1439e888886cSMaheshRavishankar     if (!parentOp)
1440e888886cSMaheshRavishankar       return failure();
14416a994233SJacques Pienaar     Block &innerBlock = op.getRegion().front();
1442e888886cSMaheshRavishankar     TerminatorOp innerTerminator =
1443e888886cSMaheshRavishankar         cast<TerminatorOp>(innerBlock.getTerminator());
14449c7b0c4aSRahul Joshi     rewriter.mergeBlockBefore(&innerBlock, op);
1445e888886cSMaheshRavishankar     rewriter.eraseOp(innerTerminator);
1446e888886cSMaheshRavishankar     rewriter.eraseOp(op);
1447e888886cSMaheshRavishankar     rewriter.updateRootInPlace(op, [] {});
1448e888886cSMaheshRavishankar     return success();
1449e888886cSMaheshRavishankar   }
1450e888886cSMaheshRavishankar };
1451e888886cSMaheshRavishankar 
1452e888886cSMaheshRavishankar struct TestMergeBlocksPatternDriver
1453e888886cSMaheshRavishankar     : public PassWrapper<TestMergeBlocksPatternDriver,
1454e888886cSMaheshRavishankar                          OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d2d11::TestMergeBlocksPatternDriver14555e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
14565e50dd04SRiver Riddle 
1457b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-merge-blocks"; }
getDescription__anon71cd160d2d11::TestMergeBlocksPatternDriver1458b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1459b5e22e6dSMehdi Amini     return "Test Merging operation in ConversionPatternRewriter";
1460b5e22e6dSMehdi Amini   }
runOnOperation__anon71cd160d2d11::TestMergeBlocksPatternDriver1461e888886cSMaheshRavishankar   void runOnOperation() override {
1462e888886cSMaheshRavishankar     MLIRContext *context = &getContext();
1463dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(context);
1464dc4e913bSChris Lattner     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1465e888886cSMaheshRavishankar         context);
1466e888886cSMaheshRavishankar     ConversionTarget target(*context);
146758ceae95SRiver Riddle     target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1468973ddb7dSMehdi Amini                       TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1469e888886cSMaheshRavishankar     target.addIllegalOp<ILLegalOpF>();
1470e888886cSMaheshRavishankar 
1471e888886cSMaheshRavishankar     /// Expect the op to have a single block after legalization.
1472e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1473e888886cSMaheshRavishankar         [&](TestMergeBlocksOp op) -> bool {
14746a994233SJacques Pienaar           return llvm::hasSingleElement(op.getBody());
1475e888886cSMaheshRavishankar         });
1476e888886cSMaheshRavishankar 
1477e888886cSMaheshRavishankar     /// Only allow `test.br` within test.merge_blocks op.
1478e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
14790bf4a82aSChristian Sigg       return op->getParentOfType<TestMergeBlocksOp>();
1480e888886cSMaheshRavishankar     });
1481e888886cSMaheshRavishankar 
1482e888886cSMaheshRavishankar     /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1483e888886cSMaheshRavishankar     /// inlined.
1484e888886cSMaheshRavishankar     target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1485e888886cSMaheshRavishankar         [&](SingleBlockImplicitTerminatorOp op) -> bool {
14860bf4a82aSChristian Sigg           return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1487e888886cSMaheshRavishankar         });
1488e888886cSMaheshRavishankar 
1489e888886cSMaheshRavishankar     DenseSet<Operation *> unlegalizedOps;
14903fffffa8SRiver Riddle     (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1491e888886cSMaheshRavishankar                                  &unlegalizedOps);
1492e888886cSMaheshRavishankar     for (auto *op : unlegalizedOps)
1493e888886cSMaheshRavishankar       op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1494e888886cSMaheshRavishankar   }
1495e888886cSMaheshRavishankar };
1496e888886cSMaheshRavishankar } // namespace
1497e888886cSMaheshRavishankar 
14984589dd92SRiver Riddle //===----------------------------------------------------------------------===//
1499c8fb6ee3SRiver Riddle // Test Selective Replacement
1500c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1501c8fb6ee3SRiver Riddle 
1502c8fb6ee3SRiver Riddle namespace {
1503c8fb6ee3SRiver Riddle /// A rewrite mechanism to inline the body of the op into its parent, when both
1504c8fb6ee3SRiver Riddle /// ops can have a single block.
1505c8fb6ee3SRiver Riddle struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1506c8fb6ee3SRiver Riddle   using OpRewritePattern<TestCastOp>::OpRewritePattern;
1507c8fb6ee3SRiver Riddle 
matchAndRewrite__anon71cd160d3411::TestSelectiveOpReplacementPattern1508c8fb6ee3SRiver Riddle   LogicalResult matchAndRewrite(TestCastOp op,
1509c8fb6ee3SRiver Riddle                                 PatternRewriter &rewriter) const final {
1510c8fb6ee3SRiver Riddle     if (op.getNumOperands() != 2)
1511c8fb6ee3SRiver Riddle       return failure();
1512c8fb6ee3SRiver Riddle     OperandRange operands = op.getOperands();
1513c8fb6ee3SRiver Riddle 
1514c8fb6ee3SRiver Riddle     // Replace non-terminator uses with the first operand.
1515c8fb6ee3SRiver Riddle     rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1516fe7c0d90SRiver Riddle       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1517c8fb6ee3SRiver Riddle     });
1518c8fb6ee3SRiver Riddle     // Replace everything else with the second operand if the operation isn't
1519c8fb6ee3SRiver Riddle     // dead.
1520c8fb6ee3SRiver Riddle     rewriter.replaceOp(op, op.getOperand(1));
1521c8fb6ee3SRiver Riddle     return success();
1522c8fb6ee3SRiver Riddle   }
1523c8fb6ee3SRiver Riddle };
1524c8fb6ee3SRiver Riddle 
1525c8fb6ee3SRiver Riddle struct TestSelectiveReplacementPatternDriver
1526c8fb6ee3SRiver Riddle     : public PassWrapper<TestSelectiveReplacementPatternDriver,
1527c8fb6ee3SRiver Riddle                          OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d3411::TestSelectiveReplacementPatternDriver15285e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
15295e50dd04SRiver Riddle       TestSelectiveReplacementPatternDriver)
15305e50dd04SRiver Riddle 
1531b5e22e6dSMehdi Amini   StringRef getArgument() const final {
1532b5e22e6dSMehdi Amini     return "test-pattern-selective-replacement";
1533b5e22e6dSMehdi Amini   }
getDescription__anon71cd160d3411::TestSelectiveReplacementPatternDriver1534b5e22e6dSMehdi Amini   StringRef getDescription() const final {
1535b5e22e6dSMehdi Amini     return "Test selective replacement in the PatternRewriter";
1536b5e22e6dSMehdi Amini   }
runOnOperation__anon71cd160d3411::TestSelectiveReplacementPatternDriver1537c8fb6ee3SRiver Riddle   void runOnOperation() override {
1538c8fb6ee3SRiver Riddle     MLIRContext *context = &getContext();
1539dc4e913bSChris Lattner     mlir::RewritePatternSet patterns(context);
1540dc4e913bSChris Lattner     patterns.add<TestSelectiveOpReplacementPattern>(context);
1541e21adfa3SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1542c8fb6ee3SRiver Riddle                                        std::move(patterns));
1543c8fb6ee3SRiver Riddle   }
1544c8fb6ee3SRiver Riddle };
1545c8fb6ee3SRiver Riddle } // namespace
1546c8fb6ee3SRiver Riddle 
1547c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
15484589dd92SRiver Riddle // PassRegistration
15494589dd92SRiver Riddle //===----------------------------------------------------------------------===//
15504589dd92SRiver Riddle 
1551fec6c5acSUday Bondhugula namespace mlir {
155272c65b69SAlexander Belyaev namespace test {
registerPatternsTestPass()1553fec6c5acSUday Bondhugula void registerPatternsTestPass() {
1554b5e22e6dSMehdi Amini   PassRegistration<TestReturnTypeDriver>();
1555fec6c5acSUday Bondhugula 
1556b5e22e6dSMehdi Amini   PassRegistration<TestDerivedAttributeDriver>();
15579ba37b3bSJacques Pienaar 
1558b5e22e6dSMehdi Amini   PassRegistration<TestPatternDriver>();
1559*ba3a9f51SChia-hung Duan   PassRegistration<TestStrictPatternDriver>();
1560fec6c5acSUday Bondhugula 
1561b5e22e6dSMehdi Amini   PassRegistration<TestLegalizePatternDriver>([] {
1562b5e22e6dSMehdi Amini     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
1563fec6c5acSUday Bondhugula   });
1564fec6c5acSUday Bondhugula 
1565b5e22e6dSMehdi Amini   PassRegistration<TestRemappedValue>();
156680d7ac3bSRiver Riddle 
1567b5e22e6dSMehdi Amini   PassRegistration<TestUnknownRootOpDriver>();
15684589dd92SRiver Riddle 
1569b5e22e6dSMehdi Amini   PassRegistration<TestTypeConversionDriver>();
1570d4a53f3bSAlex Zinenko   PassRegistration<TestTargetMaterializationWithNoUses>();
1571e888886cSMaheshRavishankar 
157288bc24a7SMathieu Fehr   PassRegistration<TestRewriteDynamicOpDriver>();
157388bc24a7SMathieu Fehr 
1574b5e22e6dSMehdi Amini   PassRegistration<TestMergeBlocksPatternDriver>();
1575b5e22e6dSMehdi Amini   PassRegistration<TestSelectiveReplacementPatternDriver>();
1576fec6c5acSUday Bondhugula }
157772c65b69SAlexander Belyaev } // namespace test
1578fec6c5acSUday Bondhugula } // namespace mlir
1579