1fec6c5acSUday Bondhugula //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2fec6c5acSUday Bondhugula //
3fec6c5acSUday Bondhugula // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fec6c5acSUday Bondhugula // See https://llvm.org/LICENSE.txt for license information.
5fec6c5acSUday Bondhugula // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fec6c5acSUday Bondhugula //
7fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
8fec6c5acSUday Bondhugula
9fec6c5acSUday Bondhugula #include "TestDialect.h"
109c5982efSAlex Zinenko #include "TestTypes.h"
11a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1223aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14c0a6318dSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
152bf423b0SRob Suderman #include "mlir/IR/Matchers.h"
16fec6c5acSUday Bondhugula #include "mlir/Pass/Pass.h"
17fec6c5acSUday Bondhugula #include "mlir/Transforms/DialectConversion.h"
1826f93d9fSAlex Zinenko #include "mlir/Transforms/FoldUtils.h"
19b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
209ba37b3bSJacques Pienaar
21fec6c5acSUday Bondhugula using namespace mlir;
227776b19eSStephen Neuendorffer using namespace test;
23fec6c5acSUday Bondhugula
24fec6c5acSUday Bondhugula // Native function for testing NativeCodeCall
chooseOperand(Value input1,Value input2,BoolAttr choice)25fec6c5acSUday Bondhugula static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
26fec6c5acSUday Bondhugula return choice.getValue() ? input1 : input2;
27fec6c5acSUday Bondhugula }
28fec6c5acSUday Bondhugula
createOpI(PatternRewriter & rewriter,Location loc,Value input)2929429d1aSJacques Pienaar static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
3029429d1aSJacques Pienaar rewriter.create<OpI>(loc, input);
31fec6c5acSUday Bondhugula }
32fec6c5acSUday Bondhugula
handleNoResultOp(PatternRewriter & rewriter,OpSymbolBindingNoResult op)33fec6c5acSUday Bondhugula static void handleNoResultOp(PatternRewriter &rewriter,
34fec6c5acSUday Bondhugula OpSymbolBindingNoResult op) {
35fec6c5acSUday Bondhugula // Turn the no result op to a one-result op.
366a994233SJacques Pienaar rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
376a994233SJacques Pienaar op.getOperand());
38fec6c5acSUday Bondhugula }
39fec6c5acSUday Bondhugula
getFirstI32Result(Operation * op,Value & value)4034b5482bSChia-hung Duan static bool getFirstI32Result(Operation *op, Value &value) {
4134b5482bSChia-hung Duan if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
4234b5482bSChia-hung Duan return false;
4334b5482bSChia-hung Duan value = op->getResult(0);
4434b5482bSChia-hung Duan return true;
4534b5482bSChia-hung Duan }
4634b5482bSChia-hung Duan
bindNativeCodeCallResult(Value value)4734b5482bSChia-hung Duan static Value bindNativeCodeCallResult(Value value) { return value; }
4834b5482bSChia-hung Duan
bindMultipleNativeCodeCallResult(Value input1,Value input2)49d7314b3cSChia-hung Duan static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
50d7314b3cSChia-hung Duan Value input2) {
51d7314b3cSChia-hung Duan return SmallVector<Value, 2>({input2, input1});
52d7314b3cSChia-hung Duan }
53d7314b3cSChia-hung Duan
5401641197SAlexEichenberger // Test that natives calls are only called once during rewrites.
5501641197SAlexEichenberger // OpM_Test will return Pi, increased by 1 for each subsequent calls.
5601641197SAlexEichenberger // This let us check the number of times OpM_Test was called by inspecting
5701641197SAlexEichenberger // the returned value in the MLIR output.
5801641197SAlexEichenberger static int64_t opMIncreasingValue = 314159265;
opMTest(PatternRewriter & rewriter,Value val)5902b6fb21SMehdi Amini static Attribute opMTest(PatternRewriter &rewriter, Value val) {
6001641197SAlexEichenberger int64_t i = opMIncreasingValue++;
6101641197SAlexEichenberger return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
6201641197SAlexEichenberger }
6301641197SAlexEichenberger
// Pull in the generated patterns and helpers from TestPatterns.inc (including
// populateWithGenerated, which this file calls). They are wrapped in an
// anonymous namespace so they get internal linkage.
namespace {
#include "TestPatterns.inc"
} // namespace
67fec6c5acSUday Bondhugula
68fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
69c484c7ddSChia-hung Duan // Test Reduce Pattern Interface
70c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
71c484c7ddSChia-hung Duan
populateTestReductionPatterns(RewritePatternSet & patterns)727776b19eSStephen Neuendorffer void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
73c484c7ddSChia-hung Duan populateWithGenerated(patterns);
74c484c7ddSChia-hung Duan }
75c484c7ddSChia-hung Duan
76c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
77fec6c5acSUday Bondhugula // Canonicalizer Driver.
78fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
79fec6c5acSUday Bondhugula
80fec6c5acSUday Bondhugula namespace {
8126f93d9fSAlex Zinenko struct FoldingPattern : public RewritePattern {
8226f93d9fSAlex Zinenko public:
FoldingPattern__anon71cd160d0211::FoldingPattern8326f93d9fSAlex Zinenko FoldingPattern(MLIRContext *context)
8426f93d9fSAlex Zinenko : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
8526f93d9fSAlex Zinenko /*benefit=*/1, context) {}
8626f93d9fSAlex Zinenko
matchAndRewrite__anon71cd160d0211::FoldingPattern8726f93d9fSAlex Zinenko LogicalResult matchAndRewrite(Operation *op,
8826f93d9fSAlex Zinenko PatternRewriter &rewriter) const override {
892b638ed5SKazuaki Ishizaki // Exercise OperationFolder API for a single-result operation that is folded
9026f93d9fSAlex Zinenko // upon construction. The operation being created through the folder has an
9126f93d9fSAlex Zinenko // in-place folder, and it should be still present in the output.
9226f93d9fSAlex Zinenko // Furthermore, the folder should not crash when attempting to recover the
93a23d0559SKazuaki Ishizaki // (unchanged) operation result.
9426f93d9fSAlex Zinenko OperationFolder folder(op->getContext());
9526f93d9fSAlex Zinenko Value result = folder.create<TestOpInPlaceFold>(
9626f93d9fSAlex Zinenko rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
9726f93d9fSAlex Zinenko rewriter.getI32IntegerAttr(0));
9826f93d9fSAlex Zinenko assert(result);
9926f93d9fSAlex Zinenko rewriter.replaceOp(op, result);
10026f93d9fSAlex Zinenko return success();
10126f93d9fSAlex Zinenko }
10226f93d9fSAlex Zinenko };
10326f93d9fSAlex Zinenko
104e4635e63SRiver Riddle /// This pattern creates a foldable operation at the entry point of the block.
105e4635e63SRiver Riddle /// This tests the situation where the operation folder will need to replace an
106e4635e63SRiver Riddle /// operation with a previously created constant that does not initially
107e4635e63SRiver Riddle /// dominate the operation to replace.
108e4635e63SRiver Riddle struct FolderInsertBeforePreviouslyFoldedConstantPattern
109e4635e63SRiver Riddle : public OpRewritePattern<TestCastOp> {
110e4635e63SRiver Riddle public:
111e4635e63SRiver Riddle using OpRewritePattern<TestCastOp>::OpRewritePattern;
112e4635e63SRiver Riddle
matchAndRewrite__anon71cd160d0211::FolderInsertBeforePreviouslyFoldedConstantPattern113e4635e63SRiver Riddle LogicalResult matchAndRewrite(TestCastOp op,
114e4635e63SRiver Riddle PatternRewriter &rewriter) const override {
115e4635e63SRiver Riddle if (!op->hasAttr("test_fold_before_previously_folded_op"))
116e4635e63SRiver Riddle return failure();
117e4635e63SRiver Riddle rewriter.setInsertionPointToStart(op->getBlock());
118e4635e63SRiver Riddle
119a54f4eaeSMogball auto constOp = rewriter.create<arith::ConstantOp>(
120a54f4eaeSMogball op.getLoc(), rewriter.getBoolAttr(true));
121e4635e63SRiver Riddle rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
122e4635e63SRiver Riddle Value(constOp));
123e4635e63SRiver Riddle return success();
124e4635e63SRiver Riddle }
125e4635e63SRiver Riddle };
126e4635e63SRiver Riddle
12743f0d5f9SMehdi Amini /// This pattern matches test.op_commutative2 with the first operand being
12843f0d5f9SMehdi Amini /// another test.op_commutative2 with a constant on the right side and fold it
12943f0d5f9SMehdi Amini /// away by propagating it as its result. This is intend to check that patterns
13043f0d5f9SMehdi Amini /// are applied after the commutative property moves constant to the right.
13143f0d5f9SMehdi Amini struct FolderCommutativeOp2WithConstant
13243f0d5f9SMehdi Amini : public OpRewritePattern<TestCommutative2Op> {
13343f0d5f9SMehdi Amini public:
13443f0d5f9SMehdi Amini using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
13543f0d5f9SMehdi Amini
matchAndRewrite__anon71cd160d0211::FolderCommutativeOp2WithConstant13643f0d5f9SMehdi Amini LogicalResult matchAndRewrite(TestCommutative2Op op,
13743f0d5f9SMehdi Amini PatternRewriter &rewriter) const override {
13843f0d5f9SMehdi Amini auto operand =
13943f0d5f9SMehdi Amini dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
14043f0d5f9SMehdi Amini if (!operand)
14143f0d5f9SMehdi Amini return failure();
14243f0d5f9SMehdi Amini Attribute constInput;
14343f0d5f9SMehdi Amini if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
14443f0d5f9SMehdi Amini return failure();
14543f0d5f9SMehdi Amini rewriter.replaceOp(op, operand->getOperand(1));
14643f0d5f9SMehdi Amini return success();
14743f0d5f9SMehdi Amini }
14843f0d5f9SMehdi Amini };
14943f0d5f9SMehdi Amini
15041574554SRiver Riddle struct TestPatternDriver
15158ceae95SRiver Riddle : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
1525e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
1535e50dd04SRiver Riddle
1547814b559Srkayaith TestPatternDriver() = default;
TestPatternDriver__anon71cd160d0211::TestPatternDriver1557814b559Srkayaith TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
1567814b559Srkayaith
getArgument__anon71cd160d0211::TestPatternDriver157b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-patterns"; }
getDescription__anon71cd160d0211::TestPatternDriver158b5e22e6dSMehdi Amini StringRef getDescription() const final { return "Run test dialect patterns"; }
runOnOperation__anon71cd160d0211::TestPatternDriver15941574554SRiver Riddle void runOnOperation() override {
160dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext());
16183e3c6a7SMehdi Amini populateWithGenerated(patterns);
162fec6c5acSUday Bondhugula
163fec6c5acSUday Bondhugula // Verify named pattern is generated with expected name.
164e4635e63SRiver Riddle patterns.add<FoldingPattern, TestNamedPatternRule,
16543f0d5f9SMehdi Amini FolderInsertBeforePreviouslyFoldedConstantPattern,
16643f0d5f9SMehdi Amini FolderCommutativeOp2WithConstant>(&getContext());
167fec6c5acSUday Bondhugula
1687814b559Srkayaith GreedyRewriteConfig config;
1697814b559Srkayaith config.useTopDownTraversal = this->useTopDownTraversal;
1707814b559Srkayaith (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
1717814b559Srkayaith config);
172fec6c5acSUday Bondhugula }
1737814b559Srkayaith
1747814b559Srkayaith Option<bool> useTopDownTraversal{
1757814b559Srkayaith *this, "top-down",
1767814b559Srkayaith llvm::cl::desc("Seed the worklist in general top-down order"),
1777814b559Srkayaith llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
178fec6c5acSUday Bondhugula };
179*ba3a9f51SChia-hung Duan
180*ba3a9f51SChia-hung Duan struct TestStrictPatternDriver
181*ba3a9f51SChia-hung Duan : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
182*ba3a9f51SChia-hung Duan public:
183*ba3a9f51SChia-hung Duan MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
184*ba3a9f51SChia-hung Duan
185*ba3a9f51SChia-hung Duan TestStrictPatternDriver() = default;
TestStrictPatternDriver__anon71cd160d0211::TestStrictPatternDriver186*ba3a9f51SChia-hung Duan TestStrictPatternDriver(const TestStrictPatternDriver &other)
187*ba3a9f51SChia-hung Duan : PassWrapper(other) {}
188*ba3a9f51SChia-hung Duan
getArgument__anon71cd160d0211::TestStrictPatternDriver189*ba3a9f51SChia-hung Duan StringRef getArgument() const final { return "test-strict-pattern-driver"; }
getDescription__anon71cd160d0211::TestStrictPatternDriver190*ba3a9f51SChia-hung Duan StringRef getDescription() const final {
191*ba3a9f51SChia-hung Duan return "Run strict mode of pattern driver";
192*ba3a9f51SChia-hung Duan }
193*ba3a9f51SChia-hung Duan
runOnOperation__anon71cd160d0211::TestStrictPatternDriver194*ba3a9f51SChia-hung Duan void runOnOperation() override {
195*ba3a9f51SChia-hung Duan mlir::RewritePatternSet patterns(&getContext());
196*ba3a9f51SChia-hung Duan patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
197*ba3a9f51SChia-hung Duan SmallVector<Operation *> ops;
198*ba3a9f51SChia-hung Duan getOperation()->walk([&](Operation *op) {
199*ba3a9f51SChia-hung Duan StringRef opName = op->getName().getStringRef();
200*ba3a9f51SChia-hung Duan if (opName == "test.insert_same_op" ||
201*ba3a9f51SChia-hung Duan opName == "test.replace_with_same_op" || opName == "test.erase_op") {
202*ba3a9f51SChia-hung Duan ops.push_back(op);
203*ba3a9f51SChia-hung Duan }
204*ba3a9f51SChia-hung Duan });
205*ba3a9f51SChia-hung Duan
206*ba3a9f51SChia-hung Duan // Check if these transformations introduce visiting of operations that
207*ba3a9f51SChia-hung Duan // are not in the `ops` set (The new created ops are valid). An invalid
208*ba3a9f51SChia-hung Duan // operation will trigger the assertion while processing.
209*ba3a9f51SChia-hung Duan (void)applyOpPatternsAndFold(makeArrayRef(ops), std::move(patterns),
210*ba3a9f51SChia-hung Duan /*strict=*/true);
211*ba3a9f51SChia-hung Duan }
212*ba3a9f51SChia-hung Duan
213*ba3a9f51SChia-hung Duan private:
214*ba3a9f51SChia-hung Duan // New inserted operation is valid for further transformation.
215*ba3a9f51SChia-hung Duan class InsertSameOp : public RewritePattern {
216*ba3a9f51SChia-hung Duan public:
InsertSameOp(MLIRContext * context)217*ba3a9f51SChia-hung Duan InsertSameOp(MLIRContext *context)
218*ba3a9f51SChia-hung Duan : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
219*ba3a9f51SChia-hung Duan
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const220*ba3a9f51SChia-hung Duan LogicalResult matchAndRewrite(Operation *op,
221*ba3a9f51SChia-hung Duan PatternRewriter &rewriter) const override {
222*ba3a9f51SChia-hung Duan if (op->hasAttr("skip"))
223*ba3a9f51SChia-hung Duan return failure();
224*ba3a9f51SChia-hung Duan
225*ba3a9f51SChia-hung Duan Operation *newOp =
226*ba3a9f51SChia-hung Duan rewriter.create(op->getLoc(), op->getName().getIdentifier(),
227*ba3a9f51SChia-hung Duan op->getOperands(), op->getResultTypes());
228*ba3a9f51SChia-hung Duan op->setAttr("skip", rewriter.getBoolAttr(true));
229*ba3a9f51SChia-hung Duan newOp->setAttr("skip", rewriter.getBoolAttr(true));
230*ba3a9f51SChia-hung Duan
231*ba3a9f51SChia-hung Duan return success();
232*ba3a9f51SChia-hung Duan }
233*ba3a9f51SChia-hung Duan };
234*ba3a9f51SChia-hung Duan
235*ba3a9f51SChia-hung Duan // Replace an operation may introduce the re-visiting of its users.
236*ba3a9f51SChia-hung Duan class ReplaceWithSameOp : public RewritePattern {
237*ba3a9f51SChia-hung Duan public:
ReplaceWithSameOp(MLIRContext * context)238*ba3a9f51SChia-hung Duan ReplaceWithSameOp(MLIRContext *context)
239*ba3a9f51SChia-hung Duan : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
240*ba3a9f51SChia-hung Duan
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const241*ba3a9f51SChia-hung Duan LogicalResult matchAndRewrite(Operation *op,
242*ba3a9f51SChia-hung Duan PatternRewriter &rewriter) const override {
243*ba3a9f51SChia-hung Duan Operation *newOp =
244*ba3a9f51SChia-hung Duan rewriter.create(op->getLoc(), op->getName().getIdentifier(),
245*ba3a9f51SChia-hung Duan op->getOperands(), op->getResultTypes());
246*ba3a9f51SChia-hung Duan rewriter.replaceOp(op, newOp->getResults());
247*ba3a9f51SChia-hung Duan return success();
248*ba3a9f51SChia-hung Duan }
249*ba3a9f51SChia-hung Duan };
250*ba3a9f51SChia-hung Duan
251*ba3a9f51SChia-hung Duan // Remove an operation may introduce the re-visiting of its opreands.
252*ba3a9f51SChia-hung Duan class EraseOp : public RewritePattern {
253*ba3a9f51SChia-hung Duan public:
EraseOp(MLIRContext * context)254*ba3a9f51SChia-hung Duan EraseOp(MLIRContext *context)
255*ba3a9f51SChia-hung Duan : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const256*ba3a9f51SChia-hung Duan LogicalResult matchAndRewrite(Operation *op,
257*ba3a9f51SChia-hung Duan PatternRewriter &rewriter) const override {
258*ba3a9f51SChia-hung Duan rewriter.eraseOp(op);
259*ba3a9f51SChia-hung Duan return success();
260*ba3a9f51SChia-hung Duan }
261*ba3a9f51SChia-hung Duan };
262*ba3a9f51SChia-hung Duan };
263*ba3a9f51SChia-hung Duan
264be0a7e9fSMehdi Amini } // namespace
265fec6c5acSUday Bondhugula
266fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
267fec6c5acSUday Bondhugula // ReturnType Driver.
268fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
269fec6c5acSUday Bondhugula
270fec6c5acSUday Bondhugula namespace {
271fec6c5acSUday Bondhugula // Generate ops for each instance where the type can be successfully inferred.
272fec6c5acSUday Bondhugula template <typename OpTy>
invokeCreateWithInferredReturnType(Operation * op)273fec6c5acSUday Bondhugula static void invokeCreateWithInferredReturnType(Operation *op) {
274fec6c5acSUday Bondhugula auto *context = op->getContext();
27558ceae95SRiver Riddle auto fop = op->getParentOfType<func::FuncOp>();
276fec6c5acSUday Bondhugula auto location = UnknownLoc::get(context);
277fec6c5acSUday Bondhugula OpBuilder b(op);
278fec6c5acSUday Bondhugula b.setInsertionPointAfter(op);
279fec6c5acSUday Bondhugula
280fec6c5acSUday Bondhugula // Use permutations of 2 args as operands.
281fec6c5acSUday Bondhugula assert(fop.getNumArguments() >= 2);
282fec6c5acSUday Bondhugula for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
283fec6c5acSUday Bondhugula for (int j = 0; j < e; ++j) {
284fec6c5acSUday Bondhugula std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
285fec6c5acSUday Bondhugula SmallVector<Type, 2> inferredReturnTypes;
2865eae715aSJacques Pienaar if (succeeded(OpTy::inferReturnTypes(
2875eae715aSJacques Pienaar context, llvm::None, values, op->getAttrDictionary(),
2885eae715aSJacques Pienaar op->getRegions(), inferredReturnTypes))) {
289fec6c5acSUday Bondhugula OperationState state(location, OpTy::getOperationName());
2909db53a18SRiver Riddle // TODO: Expand to regions.
291bb1d976fSAlex Zinenko OpTy::build(b, state, values, op->getAttrs());
29214ecafd0SChia-hung Duan (void)b.create(state);
293fec6c5acSUday Bondhugula }
294fec6c5acSUday Bondhugula }
295fec6c5acSUday Bondhugula }
296fec6c5acSUday Bondhugula }
297fec6c5acSUday Bondhugula
reifyReturnShape(Operation * op)298fec6c5acSUday Bondhugula static void reifyReturnShape(Operation *op) {
299fec6c5acSUday Bondhugula OpBuilder b(op);
300fec6c5acSUday Bondhugula
301fec6c5acSUday Bondhugula // Use permutations of 2 args as operands.
302fec6c5acSUday Bondhugula auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
303fec6c5acSUday Bondhugula SmallVector<Value, 2> shapes;
304851d02f6SWenyi Zhao if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
3059b051703SMaheshRavishankar !llvm::hasSingleElement(shapes))
306fec6c5acSUday Bondhugula return;
30789de9cc8SMehdi Amini for (const auto &it : llvm::enumerate(shapes)) {
308fec6c5acSUday Bondhugula op->emitRemark() << "value " << it.index() << ": "
309fec6c5acSUday Bondhugula << it.value().getDefiningOp();
310fec6c5acSUday Bondhugula }
3119b051703SMaheshRavishankar }
312fec6c5acSUday Bondhugula
31380aca1eaSRiver Riddle struct TestReturnTypeDriver
31458ceae95SRiver Riddle : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d0411::TestReturnTypeDriver3155e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
3165e50dd04SRiver Riddle
317e2310704SJulian Gross void getDependentDialects(DialectRegistry ®istry) const override {
318c0a6318dSMatthias Springer registry.insert<tensor::TensorDialect>();
319e2310704SJulian Gross }
getArgument__anon71cd160d0411::TestReturnTypeDriver320b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-return-type"; }
getDescription__anon71cd160d0411::TestReturnTypeDriver321b5e22e6dSMehdi Amini StringRef getDescription() const final { return "Run return type functions"; }
322e2310704SJulian Gross
runOnOperation__anon71cd160d0411::TestReturnTypeDriver32341574554SRiver Riddle void runOnOperation() override {
32441574554SRiver Riddle if (getOperation().getName() == "testCreateFunctions") {
325fec6c5acSUday Bondhugula std::vector<Operation *> ops;
326fec6c5acSUday Bondhugula // Collect ops to avoid triggering on inserted ops.
32741574554SRiver Riddle for (auto &op : getOperation().getBody().front())
328fec6c5acSUday Bondhugula ops.push_back(&op);
329fec6c5acSUday Bondhugula // Generate test patterns for each, but skip terminator.
330fec6c5acSUday Bondhugula for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
331fec6c5acSUday Bondhugula // Test create method of each of the Op classes below. The resultant
332fec6c5acSUday Bondhugula // output would be in reverse order underneath `op` from which
333fec6c5acSUday Bondhugula // the attributes and regions are used.
334fec6c5acSUday Bondhugula invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
335fec6c5acSUday Bondhugula invokeCreateWithInferredReturnType<
336fec6c5acSUday Bondhugula OpWithShapedTypeInferTypeInterfaceOp>(op);
337fec6c5acSUday Bondhugula };
338fec6c5acSUday Bondhugula return;
339fec6c5acSUday Bondhugula }
34041574554SRiver Riddle if (getOperation().getName() == "testReifyFunctions") {
341fec6c5acSUday Bondhugula std::vector<Operation *> ops;
342fec6c5acSUday Bondhugula // Collect ops to avoid triggering on inserted ops.
34341574554SRiver Riddle for (auto &op : getOperation().getBody().front())
344fec6c5acSUday Bondhugula if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
345fec6c5acSUday Bondhugula ops.push_back(&op);
346fec6c5acSUday Bondhugula // Generate test patterns for each, but skip terminator.
347fec6c5acSUday Bondhugula for (auto *op : ops)
348fec6c5acSUday Bondhugula reifyReturnShape(op);
349fec6c5acSUday Bondhugula }
350fec6c5acSUday Bondhugula }
351fec6c5acSUday Bondhugula };
352be0a7e9fSMehdi Amini } // namespace
353fec6c5acSUday Bondhugula
3549ba37b3bSJacques Pienaar namespace {
3559ba37b3bSJacques Pienaar struct TestDerivedAttributeDriver
35658ceae95SRiver Riddle : public PassWrapper<TestDerivedAttributeDriver,
35758ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d0511::TestDerivedAttributeDriver3585e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
3595e50dd04SRiver Riddle
360b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-derived-attr"; }
getDescription__anon71cd160d0511::TestDerivedAttributeDriver361b5e22e6dSMehdi Amini StringRef getDescription() const final {
362b5e22e6dSMehdi Amini return "Run test derived attributes";
363b5e22e6dSMehdi Amini }
36441574554SRiver Riddle void runOnOperation() override;
3659ba37b3bSJacques Pienaar };
366be0a7e9fSMehdi Amini } // namespace
3679ba37b3bSJacques Pienaar
runOnOperation()36841574554SRiver Riddle void TestDerivedAttributeDriver::runOnOperation() {
36941574554SRiver Riddle getOperation().walk([](DerivedAttributeOpInterface dOp) {
3709ba37b3bSJacques Pienaar auto dAttr = dOp.materializeDerivedAttributes();
3719ba37b3bSJacques Pienaar if (!dAttr)
3729ba37b3bSJacques Pienaar return;
3739ba37b3bSJacques Pienaar for (auto d : dAttr)
3740c7890c8SRiver Riddle dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
3759ba37b3bSJacques Pienaar });
3769ba37b3bSJacques Pienaar }
3779ba37b3bSJacques Pienaar
378fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
379fec6c5acSUday Bondhugula // Legalization Driver.
380fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
381fec6c5acSUday Bondhugula
382fec6c5acSUday Bondhugula namespace {
383fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
384fec6c5acSUday Bondhugula // Region-Block Rewrite Testing
385fec6c5acSUday Bondhugula
386fec6c5acSUday Bondhugula /// This pattern is a simple pattern that inlines the first region of a given
387fec6c5acSUday Bondhugula /// operation into the parent region.
388fec6c5acSUday Bondhugula struct TestRegionRewriteBlockMovement : public ConversionPattern {
TestRegionRewriteBlockMovement__anon71cd160d0711::TestRegionRewriteBlockMovement389fec6c5acSUday Bondhugula TestRegionRewriteBlockMovement(MLIRContext *ctx)
390fec6c5acSUday Bondhugula : ConversionPattern("test.region", 1, ctx) {}
391fec6c5acSUday Bondhugula
392fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d0711::TestRegionRewriteBlockMovement393fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands,
394fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final {
395fec6c5acSUday Bondhugula // Inline this region into the parent region.
396fec6c5acSUday Bondhugula auto &parentRegion = *op->getParentRegion();
397b0750e2dSTres Popp auto &opRegion = op->getRegion(0);
398fec6c5acSUday Bondhugula if (op->getAttr("legalizer.should_clone"))
399b0750e2dSTres Popp rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
400fec6c5acSUday Bondhugula else
401b0750e2dSTres Popp rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
402b0750e2dSTres Popp
403b0750e2dSTres Popp if (op->getAttr("legalizer.erase_old_blocks")) {
404b0750e2dSTres Popp while (!opRegion.empty())
405b0750e2dSTres Popp rewriter.eraseBlock(&opRegion.front());
406b0750e2dSTres Popp }
407fec6c5acSUday Bondhugula
408fec6c5acSUday Bondhugula // Drop this operation.
409fec6c5acSUday Bondhugula rewriter.eraseOp(op);
410fec6c5acSUday Bondhugula return success();
411fec6c5acSUday Bondhugula }
412fec6c5acSUday Bondhugula };
413fec6c5acSUday Bondhugula /// This pattern is a simple pattern that generates a region containing an
414fec6c5acSUday Bondhugula /// illegal operation.
415fec6c5acSUday Bondhugula struct TestRegionRewriteUndo : public RewritePattern {
TestRegionRewriteUndo__anon71cd160d0711::TestRegionRewriteUndo416fec6c5acSUday Bondhugula TestRegionRewriteUndo(MLIRContext *ctx)
417fec6c5acSUday Bondhugula : RewritePattern("test.region_builder", 1, ctx) {}
418fec6c5acSUday Bondhugula
matchAndRewrite__anon71cd160d0711::TestRegionRewriteUndo419fec6c5acSUday Bondhugula LogicalResult matchAndRewrite(Operation *op,
420fec6c5acSUday Bondhugula PatternRewriter &rewriter) const final {
421fec6c5acSUday Bondhugula // Create the region operation with an entry block containing arguments.
422fec6c5acSUday Bondhugula OperationState newRegion(op->getLoc(), "test.region");
423fec6c5acSUday Bondhugula newRegion.addRegion();
42414ecafd0SChia-hung Duan auto *regionOp = rewriter.create(newRegion);
425fec6c5acSUday Bondhugula auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0));
426e084679fSRiver Riddle entryBlock->addArgument(rewriter.getIntegerType(64),
427e084679fSRiver Riddle rewriter.getUnknownLoc());
428fec6c5acSUday Bondhugula
429fec6c5acSUday Bondhugula // Add an explicitly illegal operation to ensure the conversion fails.
430fec6c5acSUday Bondhugula rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
431fec6c5acSUday Bondhugula rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
432fec6c5acSUday Bondhugula
433fec6c5acSUday Bondhugula // Drop this operation.
434fec6c5acSUday Bondhugula rewriter.eraseOp(op);
435fec6c5acSUday Bondhugula return success();
436fec6c5acSUday Bondhugula }
437fec6c5acSUday Bondhugula };
438f27f1e8cSAlex Zinenko /// A simple pattern that creates a block at the end of the parent region of the
439f27f1e8cSAlex Zinenko /// matched operation.
440f27f1e8cSAlex Zinenko struct TestCreateBlock : public RewritePattern {
TestCreateBlock__anon71cd160d0711::TestCreateBlock441f27f1e8cSAlex Zinenko TestCreateBlock(MLIRContext *ctx)
442f27f1e8cSAlex Zinenko : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
443f27f1e8cSAlex Zinenko
matchAndRewrite__anon71cd160d0711::TestCreateBlock444f27f1e8cSAlex Zinenko LogicalResult matchAndRewrite(Operation *op,
445f27f1e8cSAlex Zinenko PatternRewriter &rewriter) const final {
446f27f1e8cSAlex Zinenko Region ®ion = *op->getParentRegion();
447f27f1e8cSAlex Zinenko Type i32Type = rewriter.getIntegerType(32);
448e084679fSRiver Riddle Location loc = op->getLoc();
449e084679fSRiver Riddle rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc});
450e084679fSRiver Riddle rewriter.create<TerminatorOp>(loc);
451f27f1e8cSAlex Zinenko rewriter.replaceOp(op, {});
452f27f1e8cSAlex Zinenko return success();
453f27f1e8cSAlex Zinenko }
454f27f1e8cSAlex Zinenko };
455f27f1e8cSAlex Zinenko
456a23d0559SKazuaki Ishizaki /// A simple pattern that creates a block containing an invalid operation in
457f27f1e8cSAlex Zinenko /// order to trigger the block creation undo mechanism.
458f27f1e8cSAlex Zinenko struct TestCreateIllegalBlock : public RewritePattern {
TestCreateIllegalBlock__anon71cd160d0711::TestCreateIllegalBlock459f27f1e8cSAlex Zinenko TestCreateIllegalBlock(MLIRContext *ctx)
460f27f1e8cSAlex Zinenko : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
461f27f1e8cSAlex Zinenko
matchAndRewrite__anon71cd160d0711::TestCreateIllegalBlock462f27f1e8cSAlex Zinenko LogicalResult matchAndRewrite(Operation *op,
463f27f1e8cSAlex Zinenko PatternRewriter &rewriter) const final {
464f27f1e8cSAlex Zinenko Region ®ion = *op->getParentRegion();
465f27f1e8cSAlex Zinenko Type i32Type = rewriter.getIntegerType(32);
466e084679fSRiver Riddle Location loc = op->getLoc();
467e084679fSRiver Riddle rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc});
468f27f1e8cSAlex Zinenko // Create an illegal op to ensure the conversion fails.
469e084679fSRiver Riddle rewriter.create<ILLegalOpF>(loc, i32Type);
470e084679fSRiver Riddle rewriter.create<TerminatorOp>(loc);
471f27f1e8cSAlex Zinenko rewriter.replaceOp(op, {});
472f27f1e8cSAlex Zinenko return success();
473f27f1e8cSAlex Zinenko }
474f27f1e8cSAlex Zinenko };
475fec6c5acSUday Bondhugula
4760816de16SRiver Riddle /// A simple pattern that tests the undo mechanism when replacing the uses of a
4770816de16SRiver Riddle /// block argument.
4780816de16SRiver Riddle struct TestUndoBlockArgReplace : public ConversionPattern {
TestUndoBlockArgReplace__anon71cd160d0711::TestUndoBlockArgReplace4790816de16SRiver Riddle TestUndoBlockArgReplace(MLIRContext *ctx)
4800816de16SRiver Riddle : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
4810816de16SRiver Riddle
4820816de16SRiver Riddle LogicalResult
matchAndRewrite__anon71cd160d0711::TestUndoBlockArgReplace4830816de16SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands,
4840816de16SRiver Riddle ConversionPatternRewriter &rewriter) const final {
4850816de16SRiver Riddle auto illegalOp =
4860816de16SRiver Riddle rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
487e2b71610SRahul Joshi rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
4880816de16SRiver Riddle illegalOp);
4890816de16SRiver Riddle rewriter.updateRootInPlace(op, [] {});
4900816de16SRiver Riddle return success();
4910816de16SRiver Riddle }
4920816de16SRiver Riddle };
4930816de16SRiver Riddle
494df48026bSAlex Zinenko /// A rewrite pattern that tests the undo mechanism when erasing a block.
495df48026bSAlex Zinenko struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase__anon71cd160d0711::TestUndoBlockErase496df48026bSAlex Zinenko TestUndoBlockErase(MLIRContext *ctx)
497df48026bSAlex Zinenko : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
498df48026bSAlex Zinenko
499df48026bSAlex Zinenko LogicalResult
matchAndRewrite__anon71cd160d0711::TestUndoBlockErase500df48026bSAlex Zinenko matchAndRewrite(Operation *op, ArrayRef<Value> operands,
501df48026bSAlex Zinenko ConversionPatternRewriter &rewriter) const final {
502df48026bSAlex Zinenko Block *secondBlock = &*std::next(op->getRegion(0).begin());
503df48026bSAlex Zinenko rewriter.setInsertionPointToStart(secondBlock);
504df48026bSAlex Zinenko rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
505df48026bSAlex Zinenko rewriter.eraseBlock(secondBlock);
506df48026bSAlex Zinenko rewriter.updateRootInPlace(op, [] {});
507df48026bSAlex Zinenko return success();
508df48026bSAlex Zinenko }
509df48026bSAlex Zinenko };
510df48026bSAlex Zinenko
511fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
512fec6c5acSUday Bondhugula // Type-Conversion Rewrite Testing
513fec6c5acSUday Bondhugula
514fec6c5acSUday Bondhugula /// This patterns erases a region operation that has had a type conversion.
515fec6c5acSUday Bondhugula struct TestDropOpSignatureConversion : public ConversionPattern {
TestDropOpSignatureConversion__anon71cd160d0711::TestDropOpSignatureConversion516fec6c5acSUday Bondhugula TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
51776f3c2f3SRiver Riddle : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
518fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d0711::TestDropOpSignatureConversion519fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands,
520fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const override {
521fec6c5acSUday Bondhugula Region ®ion = op->getRegion(0);
522fec6c5acSUday Bondhugula Block *entry = ®ion.front();
523fec6c5acSUday Bondhugula
524fec6c5acSUday Bondhugula // Convert the original entry arguments.
5258d67d187SRiver Riddle TypeConverter &converter = *getTypeConverter();
526fec6c5acSUday Bondhugula TypeConverter::SignatureConversion result(entry->getNumArguments());
5278d67d187SRiver Riddle if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
5288d67d187SRiver Riddle result)) ||
5298d67d187SRiver Riddle failed(rewriter.convertRegionTypes(®ion, converter, &result)))
530fec6c5acSUday Bondhugula return failure();
531fec6c5acSUday Bondhugula
532fec6c5acSUday Bondhugula // Convert the region signature and just drop the operation.
533fec6c5acSUday Bondhugula rewriter.eraseOp(op);
534fec6c5acSUday Bondhugula return success();
535fec6c5acSUday Bondhugula }
536fec6c5acSUday Bondhugula };
537fec6c5acSUday Bondhugula /// This pattern simply updates the operands of the given operation.
538fec6c5acSUday Bondhugula struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp__anon71cd160d0711::TestPassthroughInvalidOp539fec6c5acSUday Bondhugula TestPassthroughInvalidOp(MLIRContext *ctx)
540fec6c5acSUday Bondhugula : ConversionPattern("test.invalid", 1, ctx) {}
541fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d0711::TestPassthroughInvalidOp542fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands,
543fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final {
544fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
545fec6c5acSUday Bondhugula llvm::None);
546fec6c5acSUday Bondhugula return success();
547fec6c5acSUday Bondhugula }
548fec6c5acSUday Bondhugula };
549fec6c5acSUday Bondhugula /// This pattern handles the case of a split return value.
550fec6c5acSUday Bondhugula struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType__anon71cd160d0711::TestSplitReturnType551fec6c5acSUday Bondhugula TestSplitReturnType(MLIRContext *ctx)
552fec6c5acSUday Bondhugula : ConversionPattern("test.return", 1, ctx) {}
553fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d0711::TestSplitReturnType554fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands,
555fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final {
556fec6c5acSUday Bondhugula // Check for a return of F32.
557fec6c5acSUday Bondhugula if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
558fec6c5acSUday Bondhugula return failure();
559fec6c5acSUday Bondhugula
560fec6c5acSUday Bondhugula // Check if the first operation is a cast operation, if it is we use the
561fec6c5acSUday Bondhugula // results directly.
562fec6c5acSUday Bondhugula auto *defOp = operands[0].getDefiningOp();
563015192c6SRiver Riddle if (auto packerOp =
564015192c6SRiver Riddle llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
565fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
566fec6c5acSUday Bondhugula return success();
567fec6c5acSUday Bondhugula }
568fec6c5acSUday Bondhugula
569fec6c5acSUday Bondhugula // Otherwise, fail to match.
570fec6c5acSUday Bondhugula return failure();
571fec6c5acSUday Bondhugula }
572fec6c5acSUday Bondhugula };
573fec6c5acSUday Bondhugula
574fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
575fec6c5acSUday Bondhugula // Multi-Level Type-Conversion Rewrite Testing
576fec6c5acSUday Bondhugula struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
TestChangeProducerTypeI32ToF32__anon71cd160d0711::TestChangeProducerTypeI32ToF32577fec6c5acSUday Bondhugula TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
578fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 1, ctx) {}
579fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d0711::TestChangeProducerTypeI32ToF32580fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands,
581fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final {
582fec6c5acSUday Bondhugula // If the type is I32, change the type to F32.
583fec6c5acSUday Bondhugula if (!Type(*op->result_type_begin()).isSignlessInteger(32))
584fec6c5acSUday Bondhugula return failure();
585fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
586fec6c5acSUday Bondhugula return success();
587fec6c5acSUday Bondhugula }
588fec6c5acSUday Bondhugula };
589fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
TestChangeProducerTypeF32ToF64__anon71cd160d0711::TestChangeProducerTypeF32ToF64590fec6c5acSUday Bondhugula TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
591fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 1, ctx) {}
592fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d0711::TestChangeProducerTypeF32ToF64593fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands,
594fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final {
595fec6c5acSUday Bondhugula // If the type is F32, change the type to F64.
596fec6c5acSUday Bondhugula if (!Type(*op->result_type_begin()).isF32())
597fec6c5acSUday Bondhugula return rewriter.notifyMatchFailure(op, "expected single f32 operand");
598fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
599fec6c5acSUday Bondhugula return success();
600fec6c5acSUday Bondhugula }
601fec6c5acSUday Bondhugula };
602fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
TestChangeProducerTypeF32ToInvalid__anon71cd160d0711::TestChangeProducerTypeF32ToInvalid603fec6c5acSUday Bondhugula TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
604fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 10, ctx) {}
605fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d0711::TestChangeProducerTypeF32ToInvalid606fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands,
607fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final {
608fec6c5acSUday Bondhugula // Always convert to B16, even though it is not a legal type. This tests
609fec6c5acSUday Bondhugula // that values are unmapped correctly.
610fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
611fec6c5acSUday Bondhugula return success();
612fec6c5acSUday Bondhugula }
613fec6c5acSUday Bondhugula };
614fec6c5acSUday Bondhugula struct TestUpdateConsumerType : public ConversionPattern {
TestUpdateConsumerType__anon71cd160d0711::TestUpdateConsumerType615fec6c5acSUday Bondhugula TestUpdateConsumerType(MLIRContext *ctx)
616fec6c5acSUday Bondhugula : ConversionPattern("test.type_consumer", 1, ctx) {}
617fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d0711::TestUpdateConsumerType618fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands,
619fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final {
620fec6c5acSUday Bondhugula // Verify that the incoming operand has been successfully remapped to F64.
621fec6c5acSUday Bondhugula if (!operands[0].getType().isF64())
622fec6c5acSUday Bondhugula return failure();
623fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
624fec6c5acSUday Bondhugula return success();
625fec6c5acSUday Bondhugula }
626fec6c5acSUday Bondhugula };
627fec6c5acSUday Bondhugula
628fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
629fec6c5acSUday Bondhugula // Non-Root Replacement Rewrite Testing
630fec6c5acSUday Bondhugula /// This pattern generates an invalid operation, but replaces it before the
631fec6c5acSUday Bondhugula /// pattern is finished. This checks that we don't need to legalize the
632fec6c5acSUday Bondhugula /// temporary op.
633fec6c5acSUday Bondhugula struct TestNonRootReplacement : public RewritePattern {
TestNonRootReplacement__anon71cd160d0711::TestNonRootReplacement634fec6c5acSUday Bondhugula TestNonRootReplacement(MLIRContext *ctx)
635fec6c5acSUday Bondhugula : RewritePattern("test.replace_non_root", 1, ctx) {}
636fec6c5acSUday Bondhugula
matchAndRewrite__anon71cd160d0711::TestNonRootReplacement637fec6c5acSUday Bondhugula LogicalResult matchAndRewrite(Operation *op,
638fec6c5acSUday Bondhugula PatternRewriter &rewriter) const final {
639fec6c5acSUday Bondhugula auto resultType = *op->result_type_begin();
640fec6c5acSUday Bondhugula auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
641fec6c5acSUday Bondhugula auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
642fec6c5acSUday Bondhugula
643fec6c5acSUday Bondhugula rewriter.replaceOp(illegalOp, {legalOp});
644fec6c5acSUday Bondhugula rewriter.replaceOp(op, {illegalOp});
645fec6c5acSUday Bondhugula return success();
646fec6c5acSUday Bondhugula }
647fec6c5acSUday Bondhugula };
648bd1ccfe6SRiver Riddle
649bd1ccfe6SRiver Riddle //===----------------------------------------------------------------------===//
650bd1ccfe6SRiver Riddle // Recursive Rewrite Testing
651bd1ccfe6SRiver Riddle /// This pattern is applied to the same operation multiple times, but has a
652bd1ccfe6SRiver Riddle /// bounded recursion.
653bd1ccfe6SRiver Riddle struct TestBoundedRecursiveRewrite
654bd1ccfe6SRiver Riddle : public OpRewritePattern<TestRecursiveRewriteOp> {
6552257e4a7SRiver Riddle using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
6562257e4a7SRiver Riddle
initialize__anon71cd160d0711::TestBoundedRecursiveRewrite6572257e4a7SRiver Riddle void initialize() {
658b99bd771SRiver Riddle // The conversion target handles bounding the recursion of this pattern.
659b99bd771SRiver Riddle setHasBoundedRewriteRecursion();
660b99bd771SRiver Riddle }
661bd1ccfe6SRiver Riddle
matchAndRewrite__anon71cd160d0711::TestBoundedRecursiveRewrite662bd1ccfe6SRiver Riddle LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
663bd1ccfe6SRiver Riddle PatternRewriter &rewriter) const final {
664bd1ccfe6SRiver Riddle // Decrement the depth of the op in-place.
665bd1ccfe6SRiver Riddle rewriter.updateRootInPlace(op, [&] {
6666a994233SJacques Pienaar op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
667bd1ccfe6SRiver Riddle });
668bd1ccfe6SRiver Riddle return success();
669bd1ccfe6SRiver Riddle }
670bd1ccfe6SRiver Riddle };
6715d5df06aSAlex Zinenko
6725d5df06aSAlex Zinenko struct TestNestedOpCreationUndoRewrite
6735d5df06aSAlex Zinenko : public OpRewritePattern<IllegalOpWithRegionAnchor> {
6745d5df06aSAlex Zinenko using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
6755d5df06aSAlex Zinenko
matchAndRewrite__anon71cd160d0711::TestNestedOpCreationUndoRewrite6765d5df06aSAlex Zinenko LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
6775d5df06aSAlex Zinenko PatternRewriter &rewriter) const final {
6785d5df06aSAlex Zinenko // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
6795d5df06aSAlex Zinenko rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
6805d5df06aSAlex Zinenko return success();
6815d5df06aSAlex Zinenko };
6825d5df06aSAlex Zinenko };
683a360a978SMehdi Amini
684a360a978SMehdi Amini // This pattern matches `test.blackhole` and delete this op and its producer.
685a360a978SMehdi Amini struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
686a360a978SMehdi Amini using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
687a360a978SMehdi Amini
matchAndRewrite__anon71cd160d0711::TestReplaceEraseOp688a360a978SMehdi Amini LogicalResult matchAndRewrite(BlackHoleOp op,
689a360a978SMehdi Amini PatternRewriter &rewriter) const final {
690a360a978SMehdi Amini Operation *producer = op.getOperand().getDefiningOp();
691a360a978SMehdi Amini // Always erase the user before the producer, the framework should handle
692a360a978SMehdi Amini // this correctly.
693a360a978SMehdi Amini rewriter.eraseOp(op);
694a360a978SMehdi Amini rewriter.eraseOp(producer);
695a360a978SMehdi Amini return success();
696a360a978SMehdi Amini };
697a360a978SMehdi Amini };
698ec03bbe8SVladislav Vinogradov
699ec03bbe8SVladislav Vinogradov // This pattern replaces explicitly illegal op with explicitly legal op,
700ec03bbe8SVladislav Vinogradov // but in addition creates unregistered operation.
701ec03bbe8SVladislav Vinogradov struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
702ec03bbe8SVladislav Vinogradov using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
703ec03bbe8SVladislav Vinogradov
matchAndRewrite__anon71cd160d0711::TestCreateUnregisteredOp704ec03bbe8SVladislav Vinogradov LogicalResult matchAndRewrite(ILLegalOpG op,
705ec03bbe8SVladislav Vinogradov PatternRewriter &rewriter) const final {
706ec03bbe8SVladislav Vinogradov IntegerAttr attr = rewriter.getI32IntegerAttr(0);
7078e123ca6SRiver Riddle Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
708ec03bbe8SVladislav Vinogradov rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
709ec03bbe8SVladislav Vinogradov return success();
710ec03bbe8SVladislav Vinogradov };
711ec03bbe8SVladislav Vinogradov };
712fec6c5acSUday Bondhugula } // namespace
713fec6c5acSUday Bondhugula
714fec6c5acSUday Bondhugula namespace {
715fec6c5acSUday Bondhugula struct TestTypeConverter : public TypeConverter {
716fec6c5acSUday Bondhugula using TypeConverter::TypeConverter;
TestTypeConverter__anon71cd160d0b11::TestTypeConverter7175c5dafc5SAlex Zinenko TestTypeConverter() {
7185c5dafc5SAlex Zinenko addConversion(convertType);
7194589dd92SRiver Riddle addArgumentMaterialization(materializeCast);
7204589dd92SRiver Riddle addSourceMaterialization(materializeCast);
7215c5dafc5SAlex Zinenko }
722fec6c5acSUday Bondhugula
convertType__anon71cd160d0b11::TestTypeConverter723fec6c5acSUday Bondhugula static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
724fec6c5acSUday Bondhugula // Drop I16 types.
725fec6c5acSUday Bondhugula if (t.isSignlessInteger(16))
726fec6c5acSUday Bondhugula return success();
727fec6c5acSUday Bondhugula
728fec6c5acSUday Bondhugula // Convert I64 to F64.
729fec6c5acSUday Bondhugula if (t.isSignlessInteger(64)) {
730fec6c5acSUday Bondhugula results.push_back(FloatType::getF64(t.getContext()));
731fec6c5acSUday Bondhugula return success();
732fec6c5acSUday Bondhugula }
733fec6c5acSUday Bondhugula
7345c5dafc5SAlex Zinenko // Convert I42 to I43.
7355c5dafc5SAlex Zinenko if (t.isInteger(42)) {
7361b97cdf8SRiver Riddle results.push_back(IntegerType::get(t.getContext(), 43));
7375c5dafc5SAlex Zinenko return success();
7385c5dafc5SAlex Zinenko }
7395c5dafc5SAlex Zinenko
740fec6c5acSUday Bondhugula // Split F32 into F16,F16.
741fec6c5acSUday Bondhugula if (t.isF32()) {
742fec6c5acSUday Bondhugula results.assign(2, FloatType::getF16(t.getContext()));
743fec6c5acSUday Bondhugula return success();
744fec6c5acSUday Bondhugula }
745fec6c5acSUday Bondhugula
746fec6c5acSUday Bondhugula // Otherwise, convert the type directly.
747fec6c5acSUday Bondhugula results.push_back(t);
748fec6c5acSUday Bondhugula return success();
749fec6c5acSUday Bondhugula }
750fec6c5acSUday Bondhugula
7515c5dafc5SAlex Zinenko /// Hook for materializing a conversion. This is necessary because we generate
7525c5dafc5SAlex Zinenko /// 1->N type mappings.
materializeCast__anon71cd160d0b11::TestTypeConverter7534589dd92SRiver Riddle static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
7544589dd92SRiver Riddle ValueRange inputs, Location loc) {
7554589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
7565c5dafc5SAlex Zinenko }
757fec6c5acSUday Bondhugula };
758fec6c5acSUday Bondhugula
759fec6c5acSUday Bondhugula struct TestLegalizePatternDriver
76080aca1eaSRiver Riddle : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d0b11::TestLegalizePatternDriver7615e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
7625e50dd04SRiver Riddle
763b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-legalize-patterns"; }
getDescription__anon71cd160d0b11::TestLegalizePatternDriver764b5e22e6dSMehdi Amini StringRef getDescription() const final {
765b5e22e6dSMehdi Amini return "Run test dialect legalization patterns";
766b5e22e6dSMehdi Amini }
767fec6c5acSUday Bondhugula /// The mode of conversion to use with the driver.
768fec6c5acSUday Bondhugula enum class ConversionMode { Analysis, Full, Partial };
769fec6c5acSUday Bondhugula
TestLegalizePatternDriver__anon71cd160d0b11::TestLegalizePatternDriver770fec6c5acSUday Bondhugula TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
771fec6c5acSUday Bondhugula
getDependentDialects__anon71cd160d0b11::TestLegalizePatternDriver772ec03bbe8SVladislav Vinogradov void getDependentDialects(DialectRegistry ®istry) const override {
77323aa5a74SRiver Riddle registry.insert<func::FuncDialect>();
774ec03bbe8SVladislav Vinogradov }
775ec03bbe8SVladislav Vinogradov
runOnOperation__anon71cd160d0b11::TestLegalizePatternDriver776722f909fSRiver Riddle void runOnOperation() override {
777fec6c5acSUday Bondhugula TestTypeConverter converter;
778dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext());
7791d909c9aSChris Lattner populateWithGenerated(patterns);
780dc4e913bSChris Lattner patterns
781dc4e913bSChris Lattner .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
782dc4e913bSChris Lattner TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
783dc4e913bSChris Lattner TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
784df48026bSAlex Zinenko TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
785fec6c5acSUday Bondhugula TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
7865d5df06aSAlex Zinenko TestNonRootReplacement, TestBoundedRecursiveRewrite,
787ec03bbe8SVladislav Vinogradov TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
788ec03bbe8SVladislav Vinogradov TestCreateUnregisteredOp>(&getContext());
789dc4e913bSChris Lattner patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
79058ceae95SRiver Riddle mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
79158ceae95SRiver Riddle patterns, converter);
7923a506b31SChris Lattner mlir::populateCallOpTypeConversionPattern(patterns, converter);
793fec6c5acSUday Bondhugula
794fec6c5acSUday Bondhugula // Define the conversion target used for the test.
795fec6c5acSUday Bondhugula ConversionTarget target(getContext());
796973ddb7dSMehdi Amini target.addLegalOp<ModuleOp>();
797ec03bbe8SVladislav Vinogradov target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
798f27f1e8cSAlex Zinenko TerminatorOp>();
799fec6c5acSUday Bondhugula target
800fec6c5acSUday Bondhugula .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
801fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
802fec6c5acSUday Bondhugula // Don't allow F32 operands.
803fec6c5acSUday Bondhugula return llvm::none_of(op.getOperandTypes(),
804fec6c5acSUday Bondhugula [](Type type) { return type.isF32(); });
805fec6c5acSUday Bondhugula });
80658ceae95SRiver Riddle target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
8074a3460a7SRiver Riddle return converter.isSignatureLegal(op.getFunctionType()) &&
8088d67d187SRiver Riddle converter.isLegal(&op.getBody());
8098d67d187SRiver Riddle });
81023aa5a74SRiver Riddle target.addDynamicallyLegalOp<func::CallOp>(
81123aa5a74SRiver Riddle [&](func::CallOp op) { return converter.isLegal(op); });
812fec6c5acSUday Bondhugula
813a54f4eaeSMogball // TestCreateUnregisteredOp creates `arith.constant` operation,
814ec03bbe8SVladislav Vinogradov // which was not added to target intentionally to test
815ec03bbe8SVladislav Vinogradov // correct error code from conversion driver.
816ec03bbe8SVladislav Vinogradov target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
817ec03bbe8SVladislav Vinogradov
818fec6c5acSUday Bondhugula // Expect the type_producer/type_consumer operations to only operate on f64.
819fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestTypeProducerOp>(
820fec6c5acSUday Bondhugula [](TestTypeProducerOp op) { return op.getType().isF64(); });
821fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
822fec6c5acSUday Bondhugula return op.getOperand().getType().isF64();
823fec6c5acSUday Bondhugula });
824fec6c5acSUday Bondhugula
825fec6c5acSUday Bondhugula // Check support for marking certain operations as recursively legal.
82658ceae95SRiver Riddle target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
827fec6c5acSUday Bondhugula return static_cast<bool>(
828fec6c5acSUday Bondhugula op->getAttrOfType<UnitAttr>("test.recursively_legal"));
829fec6c5acSUday Bondhugula });
830fec6c5acSUday Bondhugula
831bd1ccfe6SRiver Riddle // Mark the bound recursion operation as dynamically legal.
832bd1ccfe6SRiver Riddle target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
8336a994233SJacques Pienaar [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
834bd1ccfe6SRiver Riddle
835fec6c5acSUday Bondhugula // Handle a partial conversion.
836fec6c5acSUday Bondhugula if (mode == ConversionMode::Partial) {
8378de482eaSLucy Fox DenseSet<Operation *> unlegalizedOps;
838ec03bbe8SVladislav Vinogradov if (failed(applyPartialConversion(
839ec03bbe8SVladislav Vinogradov getOperation(), target, std::move(patterns), &unlegalizedOps))) {
840ec03bbe8SVladislav Vinogradov getOperation()->emitRemark() << "applyPartialConversion failed";
841ec03bbe8SVladislav Vinogradov }
8428de482eaSLucy Fox // Emit remarks for each legalizable operation.
8438de482eaSLucy Fox for (auto *op : unlegalizedOps)
8448de482eaSLucy Fox op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
845fec6c5acSUday Bondhugula return;
846fec6c5acSUday Bondhugula }
847fec6c5acSUday Bondhugula
848fec6c5acSUday Bondhugula // Handle a full conversion.
849fec6c5acSUday Bondhugula if (mode == ConversionMode::Full) {
850fec6c5acSUday Bondhugula // Check support for marking unknown operations as dynamically legal.
851fec6c5acSUday Bondhugula target.markUnknownOpDynamicallyLegal([](Operation *op) {
852fec6c5acSUday Bondhugula return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
853fec6c5acSUday Bondhugula });
854fec6c5acSUday Bondhugula
855ec03bbe8SVladislav Vinogradov if (failed(applyFullConversion(getOperation(), target,
856ec03bbe8SVladislav Vinogradov std::move(patterns)))) {
857ec03bbe8SVladislav Vinogradov getOperation()->emitRemark() << "applyFullConversion failed";
858ec03bbe8SVladislav Vinogradov }
859fec6c5acSUday Bondhugula return;
860fec6c5acSUday Bondhugula }
861fec6c5acSUday Bondhugula
862fec6c5acSUday Bondhugula // Otherwise, handle an analysis conversion.
863fec6c5acSUday Bondhugula assert(mode == ConversionMode::Analysis);
864fec6c5acSUday Bondhugula
865fec6c5acSUday Bondhugula // Analyze the convertible operations.
866fec6c5acSUday Bondhugula DenseSet<Operation *> legalizedOps;
8673fffffa8SRiver Riddle if (failed(applyAnalysisConversion(getOperation(), target,
8683fffffa8SRiver Riddle std::move(patterns), legalizedOps)))
869fec6c5acSUday Bondhugula return signalPassFailure();
870fec6c5acSUday Bondhugula
871fec6c5acSUday Bondhugula // Emit remarks for each legalizable operation.
872fec6c5acSUday Bondhugula for (auto *op : legalizedOps)
873fec6c5acSUday Bondhugula op->emitRemark() << "op '" << op->getName() << "' is legalizable";
874fec6c5acSUday Bondhugula }
875fec6c5acSUday Bondhugula
876fec6c5acSUday Bondhugula /// The mode of conversion to use.
877fec6c5acSUday Bondhugula ConversionMode mode;
878fec6c5acSUday Bondhugula };
879be0a7e9fSMehdi Amini } // namespace
880fec6c5acSUday Bondhugula
881fec6c5acSUday Bondhugula static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
882fec6c5acSUday Bondhugula legalizerConversionMode(
883fec6c5acSUday Bondhugula "test-legalize-mode",
884fec6c5acSUday Bondhugula llvm::cl::desc("The legalization mode to use with the test driver"),
885fec6c5acSUday Bondhugula llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
886fec6c5acSUday Bondhugula llvm::cl::values(
887fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
888fec6c5acSUday Bondhugula "analysis", "Perform an analysis conversion"),
889fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
890fec6c5acSUday Bondhugula "Perform a full conversion"),
891fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
892fec6c5acSUday Bondhugula "partial", "Perform a partial conversion")));
893fec6c5acSUday Bondhugula
894fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
895fec6c5acSUday Bondhugula // ConversionPatternRewriter::getRemappedValue testing. This method is used
8965aacce3dSKazuaki Ishizaki // to get the remapped value of an original value that was replaced using
897fec6c5acSUday Bondhugula // ConversionPatternRewriter.
898fec6c5acSUday Bondhugula namespace {
899015192c6SRiver Riddle struct TestRemapValueTypeConverter : public TypeConverter {
900015192c6SRiver Riddle using TypeConverter::TypeConverter;
901015192c6SRiver Riddle
TestRemapValueTypeConverter__anon71cd160d1611::TestRemapValueTypeConverter902015192c6SRiver Riddle TestRemapValueTypeConverter() {
903015192c6SRiver Riddle addConversion(
904015192c6SRiver Riddle [](Float32Type type) { return Float64Type::get(type.getContext()); });
905015192c6SRiver Riddle addConversion([](Type type) { return type; });
906015192c6SRiver Riddle }
907015192c6SRiver Riddle };
908015192c6SRiver Riddle
909fec6c5acSUday Bondhugula /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
910fec6c5acSUday Bondhugula /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
911fec6c5acSUday Bondhugula /// operand twice.
912fec6c5acSUday Bondhugula ///
913fec6c5acSUday Bondhugula /// Example:
914fec6c5acSUday Bondhugula /// %1 = test.one_variadic_out_one_variadic_in1"(%0)
915fec6c5acSUday Bondhugula /// is replaced with:
916fec6c5acSUday Bondhugula /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
917fec6c5acSUday Bondhugula struct OneVResOneVOperandOp1Converter
918fec6c5acSUday Bondhugula : public OpConversionPattern<OneVResOneVOperandOp1> {
919fec6c5acSUday Bondhugula using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
920fec6c5acSUday Bondhugula
921fec6c5acSUday Bondhugula LogicalResult
matchAndRewrite__anon71cd160d1611::OneVResOneVOperandOp1Converter922ef976337SRiver Riddle matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
923fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const override {
924fec6c5acSUday Bondhugula auto origOps = op.getOperands();
925fec6c5acSUday Bondhugula assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
926fec6c5acSUday Bondhugula "One operand expected");
927fec6c5acSUday Bondhugula Value origOp = *origOps.begin();
928fec6c5acSUday Bondhugula SmallVector<Value, 2> remappedOperands;
929fec6c5acSUday Bondhugula // Replicate the remapped original operand twice. Note that we don't used
930fec6c5acSUday Bondhugula // the remapped 'operand' since the goal is testing 'getRemappedValue'.
931fec6c5acSUday Bondhugula remappedOperands.push_back(rewriter.getRemappedValue(origOp));
932fec6c5acSUday Bondhugula remappedOperands.push_back(rewriter.getRemappedValue(origOp));
933fec6c5acSUday Bondhugula
934fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
935fec6c5acSUday Bondhugula remappedOperands);
936fec6c5acSUday Bondhugula return success();
937fec6c5acSUday Bondhugula }
938fec6c5acSUday Bondhugula };
939fec6c5acSUday Bondhugula
940015192c6SRiver Riddle /// A rewriter pattern that tests that blocks can be merged.
941015192c6SRiver Riddle struct TestRemapValueInRegion
942015192c6SRiver Riddle : public OpConversionPattern<TestRemappedValueRegionOp> {
943015192c6SRiver Riddle using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
944015192c6SRiver Riddle
945015192c6SRiver Riddle LogicalResult
matchAndRewrite__anon71cd160d1611::TestRemapValueInRegion946015192c6SRiver Riddle matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
947015192c6SRiver Riddle ConversionPatternRewriter &rewriter) const final {
948015192c6SRiver Riddle Block &block = op.getBody().front();
949015192c6SRiver Riddle Operation *terminator = block.getTerminator();
950015192c6SRiver Riddle
951015192c6SRiver Riddle // Merge the block into the parent region.
952015192c6SRiver Riddle Block *parentBlock = op->getBlock();
953015192c6SRiver Riddle Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
954015192c6SRiver Riddle rewriter.mergeBlocks(&block, parentBlock, ValueRange());
955015192c6SRiver Riddle rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
956015192c6SRiver Riddle
957015192c6SRiver Riddle // Replace the results of this operation with the remapped terminator
958015192c6SRiver Riddle // values.
959015192c6SRiver Riddle SmallVector<Value> terminatorOperands;
960015192c6SRiver Riddle if (failed(rewriter.getRemappedValues(terminator->getOperands(),
961015192c6SRiver Riddle terminatorOperands)))
962015192c6SRiver Riddle return failure();
963015192c6SRiver Riddle
964015192c6SRiver Riddle rewriter.eraseOp(terminator);
965015192c6SRiver Riddle rewriter.replaceOp(op, terminatorOperands);
966015192c6SRiver Riddle return success();
967015192c6SRiver Riddle }
968015192c6SRiver Riddle };
969015192c6SRiver Riddle
97080aca1eaSRiver Riddle struct TestRemappedValue
97158ceae95SRiver Riddle : public mlir::PassWrapper<TestRemappedValue, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d1611::TestRemappedValue9725e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
9735e50dd04SRiver Riddle
974b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-remapped-value"; }
getDescription__anon71cd160d1611::TestRemappedValue975b5e22e6dSMehdi Amini StringRef getDescription() const final {
976b5e22e6dSMehdi Amini return "Test public remapped value mechanism in ConversionPatternRewriter";
977b5e22e6dSMehdi Amini }
runOnOperation__anon71cd160d1611::TestRemappedValue97841574554SRiver Riddle void runOnOperation() override {
979015192c6SRiver Riddle TestRemapValueTypeConverter typeConverter;
980015192c6SRiver Riddle
981dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext());
982dc4e913bSChris Lattner patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
983015192c6SRiver Riddle patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
984015192c6SRiver Riddle &getContext());
985015192c6SRiver Riddle patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
986fec6c5acSUday Bondhugula
987fec6c5acSUday Bondhugula mlir::ConversionTarget target(getContext());
98858ceae95SRiver Riddle target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
989015192c6SRiver Riddle
990015192c6SRiver Riddle // Expect the type_producer/type_consumer operations to only operate on f64.
991015192c6SRiver Riddle target.addDynamicallyLegalOp<TestTypeProducerOp>(
992015192c6SRiver Riddle [](TestTypeProducerOp op) { return op.getType().isF64(); });
993015192c6SRiver Riddle target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
994015192c6SRiver Riddle return op.getOperand().getType().isF64();
995015192c6SRiver Riddle });
996015192c6SRiver Riddle
997fec6c5acSUday Bondhugula // We make OneVResOneVOperandOp1 legal only when it has more that one
998fec6c5acSUday Bondhugula // operand. This will trigger the conversion that will replace one-operand
999fec6c5acSUday Bondhugula // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1000fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1001015192c6SRiver Riddle [](Operation *op) { return op->getNumOperands() > 1; });
1002fec6c5acSUday Bondhugula
100341574554SRiver Riddle if (failed(mlir::applyFullConversion(getOperation(), target,
10043fffffa8SRiver Riddle std::move(patterns)))) {
1005fec6c5acSUday Bondhugula signalPassFailure();
1006fec6c5acSUday Bondhugula }
1007fec6c5acSUday Bondhugula }
1008fec6c5acSUday Bondhugula };
1009be0a7e9fSMehdi Amini } // namespace
1010fec6c5acSUday Bondhugula
101180d7ac3bSRiver Riddle //===----------------------------------------------------------------------===//
101280d7ac3bSRiver Riddle // Test patterns without a specific root operation kind
101380d7ac3bSRiver Riddle //===----------------------------------------------------------------------===//
101480d7ac3bSRiver Riddle
101580d7ac3bSRiver Riddle namespace {
101680d7ac3bSRiver Riddle /// This pattern matches and removes any operation in the test dialect.
101780d7ac3bSRiver Riddle struct RemoveTestDialectOps : public RewritePattern {
RemoveTestDialectOps__anon71cd160d1c11::RemoveTestDialectOps101876f3c2f3SRiver Riddle RemoveTestDialectOps(MLIRContext *context)
101976f3c2f3SRiver Riddle : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
102080d7ac3bSRiver Riddle
matchAndRewrite__anon71cd160d1c11::RemoveTestDialectOps102180d7ac3bSRiver Riddle LogicalResult matchAndRewrite(Operation *op,
102280d7ac3bSRiver Riddle PatternRewriter &rewriter) const override {
102380d7ac3bSRiver Riddle if (!isa<TestDialect>(op->getDialect()))
102480d7ac3bSRiver Riddle return failure();
102580d7ac3bSRiver Riddle rewriter.eraseOp(op);
102680d7ac3bSRiver Riddle return success();
102780d7ac3bSRiver Riddle }
102880d7ac3bSRiver Riddle };
102980d7ac3bSRiver Riddle
103080d7ac3bSRiver Riddle struct TestUnknownRootOpDriver
103158ceae95SRiver Riddle : public mlir::PassWrapper<TestUnknownRootOpDriver,
103258ceae95SRiver Riddle OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d1c11::TestUnknownRootOpDriver10335e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
10345e50dd04SRiver Riddle
1035b5e22e6dSMehdi Amini StringRef getArgument() const final {
1036b5e22e6dSMehdi Amini return "test-legalize-unknown-root-patterns";
1037b5e22e6dSMehdi Amini }
getDescription__anon71cd160d1c11::TestUnknownRootOpDriver1038b5e22e6dSMehdi Amini StringRef getDescription() const final {
1039b5e22e6dSMehdi Amini return "Test public remapped value mechanism in ConversionPatternRewriter";
1040b5e22e6dSMehdi Amini }
runOnOperation__anon71cd160d1c11::TestUnknownRootOpDriver104141574554SRiver Riddle void runOnOperation() override {
1042dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext());
104376f3c2f3SRiver Riddle patterns.add<RemoveTestDialectOps>(&getContext());
104480d7ac3bSRiver Riddle
104580d7ac3bSRiver Riddle mlir::ConversionTarget target(getContext());
104680d7ac3bSRiver Riddle target.addIllegalDialect<TestDialect>();
104741574554SRiver Riddle if (failed(applyPartialConversion(getOperation(), target,
104841574554SRiver Riddle std::move(patterns))))
104980d7ac3bSRiver Riddle signalPassFailure();
105080d7ac3bSRiver Riddle }
105180d7ac3bSRiver Riddle };
1052be0a7e9fSMehdi Amini } // namespace
105380d7ac3bSRiver Riddle
10544589dd92SRiver Riddle //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
105688bc24a7SMathieu Fehr //===----------------------------------------------------------------------===//
105788bc24a7SMathieu Fehr
105888bc24a7SMathieu Fehr namespace {
105988bc24a7SMathieu Fehr /// This pattern matches dynamic operations 'test.one_operand_two_results' and
106088bc24a7SMathieu Fehr /// replace them with dynamic operations 'test.generic_dynamic_op'.
106188bc24a7SMathieu Fehr struct RewriteDynamicOp : public RewritePattern {
RewriteDynamicOp__anon71cd160d1d11::RewriteDynamicOp106288bc24a7SMathieu Fehr RewriteDynamicOp(MLIRContext *context)
106388bc24a7SMathieu Fehr : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
106488bc24a7SMathieu Fehr context) {}
106588bc24a7SMathieu Fehr
matchAndRewrite__anon71cd160d1d11::RewriteDynamicOp106688bc24a7SMathieu Fehr LogicalResult matchAndRewrite(Operation *op,
106788bc24a7SMathieu Fehr PatternRewriter &rewriter) const override {
106888bc24a7SMathieu Fehr assert(op->getName().getStringRef() ==
106988bc24a7SMathieu Fehr "test.dynamic_one_operand_two_results" &&
107088bc24a7SMathieu Fehr "rewrite pattern should only match operations with the right name");
107188bc24a7SMathieu Fehr
107288bc24a7SMathieu Fehr OperationState state(op->getLoc(), "test.dynamic_generic",
107388bc24a7SMathieu Fehr op->getOperands(), op->getResultTypes(),
107488bc24a7SMathieu Fehr op->getAttrs());
107588bc24a7SMathieu Fehr auto *newOp = rewriter.create(state);
107688bc24a7SMathieu Fehr rewriter.replaceOp(op, newOp->getResults());
107788bc24a7SMathieu Fehr return success();
107888bc24a7SMathieu Fehr }
107988bc24a7SMathieu Fehr };
108088bc24a7SMathieu Fehr
108188bc24a7SMathieu Fehr struct TestRewriteDynamicOpDriver
108288bc24a7SMathieu Fehr : public PassWrapper<TestRewriteDynamicOpDriver,
108388bc24a7SMathieu Fehr OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d1d11::TestRewriteDynamicOpDriver108488bc24a7SMathieu Fehr MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
108588bc24a7SMathieu Fehr
108688bc24a7SMathieu Fehr void getDependentDialects(DialectRegistry ®istry) const override {
108788bc24a7SMathieu Fehr registry.insert<TestDialect>();
108888bc24a7SMathieu Fehr }
getArgument__anon71cd160d1d11::TestRewriteDynamicOpDriver108988bc24a7SMathieu Fehr StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
getDescription__anon71cd160d1d11::TestRewriteDynamicOpDriver109088bc24a7SMathieu Fehr StringRef getDescription() const final {
109188bc24a7SMathieu Fehr return "Test rewritting on dynamic operations";
109288bc24a7SMathieu Fehr }
runOnOperation__anon71cd160d1d11::TestRewriteDynamicOpDriver109388bc24a7SMathieu Fehr void runOnOperation() override {
109488bc24a7SMathieu Fehr RewritePatternSet patterns(&getContext());
109588bc24a7SMathieu Fehr patterns.add<RewriteDynamicOp>(&getContext());
109688bc24a7SMathieu Fehr
109788bc24a7SMathieu Fehr ConversionTarget target(getContext());
109888bc24a7SMathieu Fehr target.addIllegalOp(
109988bc24a7SMathieu Fehr OperationName("test.dynamic_one_operand_two_results", &getContext()));
110088bc24a7SMathieu Fehr target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
110188bc24a7SMathieu Fehr if (failed(applyPartialConversion(getOperation(), target,
110288bc24a7SMathieu Fehr std::move(patterns))))
110388bc24a7SMathieu Fehr signalPassFailure();
110488bc24a7SMathieu Fehr }
110588bc24a7SMathieu Fehr };
110688bc24a7SMathieu Fehr } // end anonymous namespace
110788bc24a7SMathieu Fehr
110888bc24a7SMathieu Fehr //===----------------------------------------------------------------------===//
11094589dd92SRiver Riddle // Test type conversions
11104589dd92SRiver Riddle //===----------------------------------------------------------------------===//
11114589dd92SRiver Riddle
11124589dd92SRiver Riddle namespace {
11134589dd92SRiver Riddle struct TestTypeConversionProducer
11144589dd92SRiver Riddle : public OpConversionPattern<TestTypeProducerOp> {
11154589dd92SRiver Riddle using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
11164589dd92SRiver Riddle LogicalResult
matchAndRewrite__anon71cd160d1e11::TestTypeConversionProducer1117ef976337SRiver Riddle matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
11184589dd92SRiver Riddle ConversionPatternRewriter &rewriter) const final {
11194589dd92SRiver Riddle Type resultType = op.getType();
11209c5982efSAlex Zinenko Type convertedType = getTypeConverter()
11219c5982efSAlex Zinenko ? getTypeConverter()->convertType(resultType)
11229c5982efSAlex Zinenko : resultType;
11234589dd92SRiver Riddle if (resultType.isa<FloatType>())
11244589dd92SRiver Riddle resultType = rewriter.getF64Type();
11254589dd92SRiver Riddle else if (resultType.isInteger(16))
11264589dd92SRiver Riddle resultType = rewriter.getIntegerType(64);
11279c5982efSAlex Zinenko else if (resultType.isa<test::TestRecursiveType>() &&
11289c5982efSAlex Zinenko convertedType != resultType)
11299c5982efSAlex Zinenko resultType = convertedType;
11304589dd92SRiver Riddle else
11314589dd92SRiver Riddle return failure();
11324589dd92SRiver Riddle
11334589dd92SRiver Riddle rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
11344589dd92SRiver Riddle return success();
11354589dd92SRiver Riddle }
11364589dd92SRiver Riddle };
11374589dd92SRiver Riddle
11380409eb28SAlex Zinenko /// Call signature conversion and then fail the rewrite to trigger the undo
11390409eb28SAlex Zinenko /// mechanism.
11400409eb28SAlex Zinenko struct TestSignatureConversionUndo
11410409eb28SAlex Zinenko : public OpConversionPattern<TestSignatureConversionUndoOp> {
11420409eb28SAlex Zinenko using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
11430409eb28SAlex Zinenko
11440409eb28SAlex Zinenko LogicalResult
matchAndRewrite__anon71cd160d1e11::TestSignatureConversionUndo1145ef976337SRiver Riddle matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
11460409eb28SAlex Zinenko ConversionPatternRewriter &rewriter) const final {
11470409eb28SAlex Zinenko (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
11480409eb28SAlex Zinenko return failure();
11490409eb28SAlex Zinenko }
11500409eb28SAlex Zinenko };
11510409eb28SAlex Zinenko
11524070f305SRiver Riddle /// Call signature conversion without providing a type converter to handle
11534070f305SRiver Riddle /// materializations.
11544070f305SRiver Riddle struct TestTestSignatureConversionNoConverter
11554070f305SRiver Riddle : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
TestTestSignatureConversionNoConverter__anon71cd160d1e11::TestTestSignatureConversionNoConverter11564070f305SRiver Riddle TestTestSignatureConversionNoConverter(TypeConverter &converter,
11574070f305SRiver Riddle MLIRContext *context)
11584070f305SRiver Riddle : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
11594070f305SRiver Riddle converter(converter) {}
11604070f305SRiver Riddle
11614070f305SRiver Riddle LogicalResult
matchAndRewrite__anon71cd160d1e11::TestTestSignatureConversionNoConverter11624070f305SRiver Riddle matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
11634070f305SRiver Riddle ConversionPatternRewriter &rewriter) const final {
11644070f305SRiver Riddle Region ®ion = op->getRegion(0);
11654070f305SRiver Riddle Block *entry = ®ion.front();
11664070f305SRiver Riddle
11674070f305SRiver Riddle // Convert the original entry arguments.
11684070f305SRiver Riddle TypeConverter::SignatureConversion result(entry->getNumArguments());
11694070f305SRiver Riddle if (failed(
11704070f305SRiver Riddle converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
11714070f305SRiver Riddle return failure();
11724070f305SRiver Riddle rewriter.updateRootInPlace(
11734070f305SRiver Riddle op, [&] { rewriter.applySignatureConversion(®ion, result); });
11744070f305SRiver Riddle return success();
11754070f305SRiver Riddle }
11764070f305SRiver Riddle
11774070f305SRiver Riddle TypeConverter &converter;
11784070f305SRiver Riddle };
11794070f305SRiver Riddle
11800409eb28SAlex Zinenko /// Just forward the operands to the root op. This is essentially a no-op
11810409eb28SAlex Zinenko /// pattern that is used to trigger target materialization.
11820409eb28SAlex Zinenko struct TestTypeConsumerForward
11830409eb28SAlex Zinenko : public OpConversionPattern<TestTypeConsumerOp> {
11840409eb28SAlex Zinenko using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
11850409eb28SAlex Zinenko
11860409eb28SAlex Zinenko LogicalResult
matchAndRewrite__anon71cd160d1e11::TestTypeConsumerForward1187ef976337SRiver Riddle matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
11880409eb28SAlex Zinenko ConversionPatternRewriter &rewriter) const final {
1189ef976337SRiver Riddle rewriter.updateRootInPlace(op,
1190ef976337SRiver Riddle [&] { op->setOperands(adaptor.getOperands()); });
11910409eb28SAlex Zinenko return success();
11920409eb28SAlex Zinenko }
11930409eb28SAlex Zinenko };
11940409eb28SAlex Zinenko
11955b91060dSAlex Zinenko struct TestTypeConversionAnotherProducer
11965b91060dSAlex Zinenko : public OpRewritePattern<TestAnotherTypeProducerOp> {
11975b91060dSAlex Zinenko using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
11985b91060dSAlex Zinenko
matchAndRewrite__anon71cd160d1e11::TestTypeConversionAnotherProducer11995b91060dSAlex Zinenko LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
12005b91060dSAlex Zinenko PatternRewriter &rewriter) const final {
12015b91060dSAlex Zinenko rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
12025b91060dSAlex Zinenko return success();
12035b91060dSAlex Zinenko }
12045b91060dSAlex Zinenko };
12055b91060dSAlex Zinenko
12064589dd92SRiver Riddle struct TestTypeConversionDriver
12074589dd92SRiver Riddle : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d1e11::TestTypeConversionDriver12085e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
12095e50dd04SRiver Riddle
1210f9dc2b70SMehdi Amini void getDependentDialects(DialectRegistry ®istry) const override {
1211f9dc2b70SMehdi Amini registry.insert<TestDialect>();
1212f9dc2b70SMehdi Amini }
getArgument__anon71cd160d1e11::TestTypeConversionDriver1213b5e22e6dSMehdi Amini StringRef getArgument() const final {
1214b5e22e6dSMehdi Amini return "test-legalize-type-conversion";
1215b5e22e6dSMehdi Amini }
getDescription__anon71cd160d1e11::TestTypeConversionDriver1216b5e22e6dSMehdi Amini StringRef getDescription() const final {
1217b5e22e6dSMehdi Amini return "Test various type conversion functionalities in DialectConversion";
1218b5e22e6dSMehdi Amini }
1219f9dc2b70SMehdi Amini
runOnOperation__anon71cd160d1e11::TestTypeConversionDriver12204589dd92SRiver Riddle void runOnOperation() override {
12214589dd92SRiver Riddle // Initialize the type converter.
12224589dd92SRiver Riddle TypeConverter converter;
12234589dd92SRiver Riddle
12244589dd92SRiver Riddle /// Add the legal set of type conversions.
12254589dd92SRiver Riddle converter.addConversion([](Type type) -> Type {
12264589dd92SRiver Riddle // Treat F64 as legal.
12274589dd92SRiver Riddle if (type.isF64())
12284589dd92SRiver Riddle return type;
12294589dd92SRiver Riddle // Allow converting BF16/F16/F32 to F64.
12304589dd92SRiver Riddle if (type.isBF16() || type.isF16() || type.isF32())
12314589dd92SRiver Riddle return FloatType::getF64(type.getContext());
12324589dd92SRiver Riddle // Otherwise, the type is illegal.
12334589dd92SRiver Riddle return nullptr;
12344589dd92SRiver Riddle });
12354589dd92SRiver Riddle converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
12364589dd92SRiver Riddle // Drop all integer types.
12374589dd92SRiver Riddle return success();
12384589dd92SRiver Riddle });
12399c5982efSAlex Zinenko converter.addConversion(
12409c5982efSAlex Zinenko // Convert a recursive self-referring type into a non-self-referring
12419c5982efSAlex Zinenko // type named "outer_converted_type" that contains a SimpleAType.
12429c5982efSAlex Zinenko [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
12439c5982efSAlex Zinenko ArrayRef<Type> callStack) -> Optional<LogicalResult> {
12449c5982efSAlex Zinenko // If the type is already converted, return it to indicate that it is
12459c5982efSAlex Zinenko // legal.
12469c5982efSAlex Zinenko if (type.getName() == "outer_converted_type") {
12479c5982efSAlex Zinenko results.push_back(type);
12489c5982efSAlex Zinenko return success();
12499c5982efSAlex Zinenko }
12509c5982efSAlex Zinenko
12519c5982efSAlex Zinenko // If the type is on the call stack more than once (it is there at
12529c5982efSAlex Zinenko // least once because of the _current_ call, which is always the last
12539c5982efSAlex Zinenko // element on the stack), we've hit the recursive case. Just return
12549c5982efSAlex Zinenko // SimpleAType here to create a non-recursive type as a result.
12559c5982efSAlex Zinenko if (llvm::is_contained(callStack.drop_back(), type)) {
12569c5982efSAlex Zinenko results.push_back(test::SimpleAType::get(type.getContext()));
12579c5982efSAlex Zinenko return success();
12589c5982efSAlex Zinenko }
12599c5982efSAlex Zinenko
12609c5982efSAlex Zinenko // Convert the body recursively.
12619c5982efSAlex Zinenko auto result = test::TestRecursiveType::get(type.getContext(),
12629c5982efSAlex Zinenko "outer_converted_type");
12639c5982efSAlex Zinenko if (failed(result.setBody(converter.convertType(type.getBody()))))
12649c5982efSAlex Zinenko return failure();
12659c5982efSAlex Zinenko results.push_back(result);
12669c5982efSAlex Zinenko return success();
12679c5982efSAlex Zinenko });
12684589dd92SRiver Riddle
12694589dd92SRiver Riddle /// Add the legal set of type materializations.
12704589dd92SRiver Riddle converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
12714589dd92SRiver Riddle ValueRange inputs,
12724589dd92SRiver Riddle Location loc) -> Value {
12734589dd92SRiver Riddle // Allow casting from F64 back to F32.
12744589dd92SRiver Riddle if (!resultType.isF16() && inputs.size() == 1 &&
12754589dd92SRiver Riddle inputs[0].getType().isF64())
12764589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
12774589dd92SRiver Riddle // Allow producing an i32 or i64 from nothing.
12784589dd92SRiver Riddle if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
12794589dd92SRiver Riddle inputs.empty())
12804589dd92SRiver Riddle return builder.create<TestTypeProducerOp>(loc, resultType);
12814589dd92SRiver Riddle // Allow producing an i64 from an integer.
12824589dd92SRiver Riddle if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
12834589dd92SRiver Riddle inputs[0].getType().isa<IntegerType>())
12844589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
12854589dd92SRiver Riddle // Otherwise, fail.
12864589dd92SRiver Riddle return nullptr;
12874589dd92SRiver Riddle });
12884589dd92SRiver Riddle
12894589dd92SRiver Riddle // Initialize the conversion target.
12904589dd92SRiver Riddle mlir::ConversionTarget target(getContext());
12914589dd92SRiver Riddle target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
12929c5982efSAlex Zinenko auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
12939c5982efSAlex Zinenko return op.getType().isF64() || op.getType().isInteger(64) ||
12949c5982efSAlex Zinenko (recursiveType &&
12959c5982efSAlex Zinenko recursiveType.getName() == "outer_converted_type");
12964589dd92SRiver Riddle });
129758ceae95SRiver Riddle target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
12984a3460a7SRiver Riddle return converter.isSignatureLegal(op.getFunctionType()) &&
12994589dd92SRiver Riddle converter.isLegal(&op.getBody());
13004589dd92SRiver Riddle });
13014589dd92SRiver Riddle target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
13024589dd92SRiver Riddle // Allow casts from F64 to F32.
13034589dd92SRiver Riddle return (*op.operand_type_begin()).isF64() && op.getType().isF32();
13044589dd92SRiver Riddle });
13054070f305SRiver Riddle target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
13064070f305SRiver Riddle [&](TestSignatureConversionNoConverterOp op) {
13074070f305SRiver Riddle return converter.isLegal(op.getRegion().front().getArgumentTypes());
13084070f305SRiver Riddle });
13094589dd92SRiver Riddle
13104589dd92SRiver Riddle // Initialize the set of rewrite patterns.
1311dc4e913bSChris Lattner RewritePatternSet patterns(&getContext());
1312dc4e913bSChris Lattner patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
13134070f305SRiver Riddle TestSignatureConversionUndo,
13144070f305SRiver Riddle TestTestSignatureConversionNoConverter>(converter,
13154070f305SRiver Riddle &getContext());
1316dc4e913bSChris Lattner patterns.add<TestTypeConversionAnotherProducer>(&getContext());
131758ceae95SRiver Riddle mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
131858ceae95SRiver Riddle patterns, converter);
13194589dd92SRiver Riddle
13203fffffa8SRiver Riddle if (failed(applyPartialConversion(getOperation(), target,
13213fffffa8SRiver Riddle std::move(patterns))))
13224589dd92SRiver Riddle signalPassFailure();
13234589dd92SRiver Riddle }
13244589dd92SRiver Riddle };
1325be0a7e9fSMehdi Amini } // namespace
13264589dd92SRiver Riddle
1327c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1328d4a53f3bSAlex Zinenko // Test Target Materialization With No Uses
1329d4a53f3bSAlex Zinenko //===----------------------------------------------------------------------===//
1330d4a53f3bSAlex Zinenko
1331d4a53f3bSAlex Zinenko namespace {
1332d4a53f3bSAlex Zinenko struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1333d4a53f3bSAlex Zinenko using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1334d4a53f3bSAlex Zinenko
1335d4a53f3bSAlex Zinenko LogicalResult
matchAndRewrite__anon71cd160d2911::ForwardOperandPattern1336d4a53f3bSAlex Zinenko matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1337d4a53f3bSAlex Zinenko ConversionPatternRewriter &rewriter) const final {
1338d4a53f3bSAlex Zinenko rewriter.replaceOp(op, adaptor.getOperands());
1339d4a53f3bSAlex Zinenko return success();
1340d4a53f3bSAlex Zinenko }
1341d4a53f3bSAlex Zinenko };
1342d4a53f3bSAlex Zinenko
1343d4a53f3bSAlex Zinenko struct TestTargetMaterializationWithNoUses
1344d4a53f3bSAlex Zinenko : public PassWrapper<TestTargetMaterializationWithNoUses,
1345d4a53f3bSAlex Zinenko OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d2911::TestTargetMaterializationWithNoUses13465e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
13475e50dd04SRiver Riddle TestTargetMaterializationWithNoUses)
13485e50dd04SRiver Riddle
1349d4a53f3bSAlex Zinenko StringRef getArgument() const final {
1350d4a53f3bSAlex Zinenko return "test-target-materialization-with-no-uses";
1351d4a53f3bSAlex Zinenko }
getDescription__anon71cd160d2911::TestTargetMaterializationWithNoUses1352d4a53f3bSAlex Zinenko StringRef getDescription() const final {
1353d4a53f3bSAlex Zinenko return "Test a special case of target materialization in DialectConversion";
1354d4a53f3bSAlex Zinenko }
1355d4a53f3bSAlex Zinenko
runOnOperation__anon71cd160d2911::TestTargetMaterializationWithNoUses1356d4a53f3bSAlex Zinenko void runOnOperation() override {
1357d4a53f3bSAlex Zinenko TypeConverter converter;
1358d4a53f3bSAlex Zinenko converter.addConversion([](Type t) { return t; });
1359d4a53f3bSAlex Zinenko converter.addConversion([](IntegerType intTy) -> Type {
1360d4a53f3bSAlex Zinenko if (intTy.getWidth() == 16)
1361d4a53f3bSAlex Zinenko return IntegerType::get(intTy.getContext(), 64);
1362d4a53f3bSAlex Zinenko return intTy;
1363d4a53f3bSAlex Zinenko });
1364d4a53f3bSAlex Zinenko converter.addTargetMaterialization(
1365d4a53f3bSAlex Zinenko [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1366d4a53f3bSAlex Zinenko return builder.create<TestCastOp>(loc, type, inputs).getResult();
1367d4a53f3bSAlex Zinenko });
1368d4a53f3bSAlex Zinenko
1369d4a53f3bSAlex Zinenko ConversionTarget target(getContext());
1370d4a53f3bSAlex Zinenko target.addIllegalOp<TestTypeChangerOp>();
1371d4a53f3bSAlex Zinenko
1372d4a53f3bSAlex Zinenko RewritePatternSet patterns(&getContext());
1373d4a53f3bSAlex Zinenko patterns.add<ForwardOperandPattern>(converter, &getContext());
1374d4a53f3bSAlex Zinenko
1375d4a53f3bSAlex Zinenko if (failed(applyPartialConversion(getOperation(), target,
1376d4a53f3bSAlex Zinenko std::move(patterns))))
1377d4a53f3bSAlex Zinenko signalPassFailure();
1378d4a53f3bSAlex Zinenko }
1379d4a53f3bSAlex Zinenko };
1380d4a53f3bSAlex Zinenko } // namespace
1381d4a53f3bSAlex Zinenko
1382d4a53f3bSAlex Zinenko //===----------------------------------------------------------------------===//
1383c8fb6ee3SRiver Riddle // Test Block Merging
1384c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1385c8fb6ee3SRiver Riddle
1386e888886cSMaheshRavishankar namespace {
1387e888886cSMaheshRavishankar /// A rewriter pattern that tests that blocks can be merged.
1388e888886cSMaheshRavishankar struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1389e888886cSMaheshRavishankar using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1390e888886cSMaheshRavishankar
1391e888886cSMaheshRavishankar LogicalResult
matchAndRewrite__anon71cd160d2d11::TestMergeBlock1392ef976337SRiver Riddle matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1393e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final {
13946a994233SJacques Pienaar Block &firstBlock = op.getBody().front();
1395e888886cSMaheshRavishankar Operation *branchOp = firstBlock.getTerminator();
13966a994233SJacques Pienaar Block *secondBlock = &*(std::next(op.getBody().begin()));
1397e888886cSMaheshRavishankar auto succOperands = branchOp->getOperands();
1398e888886cSMaheshRavishankar SmallVector<Value, 2> replacements(succOperands);
1399e888886cSMaheshRavishankar rewriter.eraseOp(branchOp);
1400e888886cSMaheshRavishankar rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1401e888886cSMaheshRavishankar rewriter.updateRootInPlace(op, [] {});
1402e888886cSMaheshRavishankar return success();
1403e888886cSMaheshRavishankar }
1404e888886cSMaheshRavishankar };
1405e888886cSMaheshRavishankar
1406e888886cSMaheshRavishankar /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1407e888886cSMaheshRavishankar struct TestUndoBlocksMerge : public ConversionPattern {
TestUndoBlocksMerge__anon71cd160d2d11::TestUndoBlocksMerge1408e888886cSMaheshRavishankar TestUndoBlocksMerge(MLIRContext *ctx)
1409e888886cSMaheshRavishankar : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1410e888886cSMaheshRavishankar LogicalResult
matchAndRewrite__anon71cd160d2d11::TestUndoBlocksMerge1411e888886cSMaheshRavishankar matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1412e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final {
1413e888886cSMaheshRavishankar Block &firstBlock = op->getRegion(0).front();
1414e888886cSMaheshRavishankar Operation *branchOp = firstBlock.getTerminator();
1415e888886cSMaheshRavishankar Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1416e888886cSMaheshRavishankar rewriter.setInsertionPointToStart(secondBlock);
1417e888886cSMaheshRavishankar rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1418e888886cSMaheshRavishankar auto succOperands = branchOp->getOperands();
1419e888886cSMaheshRavishankar SmallVector<Value, 2> replacements(succOperands);
1420e888886cSMaheshRavishankar rewriter.eraseOp(branchOp);
1421e888886cSMaheshRavishankar rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1422e888886cSMaheshRavishankar rewriter.updateRootInPlace(op, [] {});
1423e888886cSMaheshRavishankar return success();
1424e888886cSMaheshRavishankar }
1425e888886cSMaheshRavishankar };
1426e888886cSMaheshRavishankar
1427e888886cSMaheshRavishankar /// A rewrite mechanism to inline the body of the op into its parent, when both
1428e888886cSMaheshRavishankar /// ops can have a single block.
1429e888886cSMaheshRavishankar struct TestMergeSingleBlockOps
1430e888886cSMaheshRavishankar : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1431e888886cSMaheshRavishankar using OpConversionPattern<
1432e888886cSMaheshRavishankar SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1433e888886cSMaheshRavishankar
1434e888886cSMaheshRavishankar LogicalResult
matchAndRewrite__anon71cd160d2d11::TestMergeSingleBlockOps1435ef976337SRiver Riddle matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1436e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final {
1437e888886cSMaheshRavishankar SingleBlockImplicitTerminatorOp parentOp =
14380bf4a82aSChristian Sigg op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1439e888886cSMaheshRavishankar if (!parentOp)
1440e888886cSMaheshRavishankar return failure();
14416a994233SJacques Pienaar Block &innerBlock = op.getRegion().front();
1442e888886cSMaheshRavishankar TerminatorOp innerTerminator =
1443e888886cSMaheshRavishankar cast<TerminatorOp>(innerBlock.getTerminator());
14449c7b0c4aSRahul Joshi rewriter.mergeBlockBefore(&innerBlock, op);
1445e888886cSMaheshRavishankar rewriter.eraseOp(innerTerminator);
1446e888886cSMaheshRavishankar rewriter.eraseOp(op);
1447e888886cSMaheshRavishankar rewriter.updateRootInPlace(op, [] {});
1448e888886cSMaheshRavishankar return success();
1449e888886cSMaheshRavishankar }
1450e888886cSMaheshRavishankar };
1451e888886cSMaheshRavishankar
1452e888886cSMaheshRavishankar struct TestMergeBlocksPatternDriver
1453e888886cSMaheshRavishankar : public PassWrapper<TestMergeBlocksPatternDriver,
1454e888886cSMaheshRavishankar OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d2d11::TestMergeBlocksPatternDriver14555e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
14565e50dd04SRiver Riddle
1457b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-merge-blocks"; }
getDescription__anon71cd160d2d11::TestMergeBlocksPatternDriver1458b5e22e6dSMehdi Amini StringRef getDescription() const final {
1459b5e22e6dSMehdi Amini return "Test Merging operation in ConversionPatternRewriter";
1460b5e22e6dSMehdi Amini }
runOnOperation__anon71cd160d2d11::TestMergeBlocksPatternDriver1461e888886cSMaheshRavishankar void runOnOperation() override {
1462e888886cSMaheshRavishankar MLIRContext *context = &getContext();
1463dc4e913bSChris Lattner mlir::RewritePatternSet patterns(context);
1464dc4e913bSChris Lattner patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1465e888886cSMaheshRavishankar context);
1466e888886cSMaheshRavishankar ConversionTarget target(*context);
146758ceae95SRiver Riddle target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1468973ddb7dSMehdi Amini TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1469e888886cSMaheshRavishankar target.addIllegalOp<ILLegalOpF>();
1470e888886cSMaheshRavishankar
1471e888886cSMaheshRavishankar /// Expect the op to have a single block after legalization.
1472e888886cSMaheshRavishankar target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1473e888886cSMaheshRavishankar [&](TestMergeBlocksOp op) -> bool {
14746a994233SJacques Pienaar return llvm::hasSingleElement(op.getBody());
1475e888886cSMaheshRavishankar });
1476e888886cSMaheshRavishankar
1477e888886cSMaheshRavishankar /// Only allow `test.br` within test.merge_blocks op.
1478e888886cSMaheshRavishankar target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
14790bf4a82aSChristian Sigg return op->getParentOfType<TestMergeBlocksOp>();
1480e888886cSMaheshRavishankar });
1481e888886cSMaheshRavishankar
1482e888886cSMaheshRavishankar /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1483e888886cSMaheshRavishankar /// inlined.
1484e888886cSMaheshRavishankar target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1485e888886cSMaheshRavishankar [&](SingleBlockImplicitTerminatorOp op) -> bool {
14860bf4a82aSChristian Sigg return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1487e888886cSMaheshRavishankar });
1488e888886cSMaheshRavishankar
1489e888886cSMaheshRavishankar DenseSet<Operation *> unlegalizedOps;
14903fffffa8SRiver Riddle (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1491e888886cSMaheshRavishankar &unlegalizedOps);
1492e888886cSMaheshRavishankar for (auto *op : unlegalizedOps)
1493e888886cSMaheshRavishankar op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1494e888886cSMaheshRavishankar }
1495e888886cSMaheshRavishankar };
1496e888886cSMaheshRavishankar } // namespace
1497e888886cSMaheshRavishankar
14984589dd92SRiver Riddle //===----------------------------------------------------------------------===//
1499c8fb6ee3SRiver Riddle // Test Selective Replacement
1500c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
1501c8fb6ee3SRiver Riddle
1502c8fb6ee3SRiver Riddle namespace {
1503c8fb6ee3SRiver Riddle /// A rewrite mechanism to inline the body of the op into its parent, when both
1504c8fb6ee3SRiver Riddle /// ops can have a single block.
1505c8fb6ee3SRiver Riddle struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1506c8fb6ee3SRiver Riddle using OpRewritePattern<TestCastOp>::OpRewritePattern;
1507c8fb6ee3SRiver Riddle
matchAndRewrite__anon71cd160d3411::TestSelectiveOpReplacementPattern1508c8fb6ee3SRiver Riddle LogicalResult matchAndRewrite(TestCastOp op,
1509c8fb6ee3SRiver Riddle PatternRewriter &rewriter) const final {
1510c8fb6ee3SRiver Riddle if (op.getNumOperands() != 2)
1511c8fb6ee3SRiver Riddle return failure();
1512c8fb6ee3SRiver Riddle OperandRange operands = op.getOperands();
1513c8fb6ee3SRiver Riddle
1514c8fb6ee3SRiver Riddle // Replace non-terminator uses with the first operand.
1515c8fb6ee3SRiver Riddle rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1516fe7c0d90SRiver Riddle return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1517c8fb6ee3SRiver Riddle });
1518c8fb6ee3SRiver Riddle // Replace everything else with the second operand if the operation isn't
1519c8fb6ee3SRiver Riddle // dead.
1520c8fb6ee3SRiver Riddle rewriter.replaceOp(op, op.getOperand(1));
1521c8fb6ee3SRiver Riddle return success();
1522c8fb6ee3SRiver Riddle }
1523c8fb6ee3SRiver Riddle };
1524c8fb6ee3SRiver Riddle
1525c8fb6ee3SRiver Riddle struct TestSelectiveReplacementPatternDriver
1526c8fb6ee3SRiver Riddle : public PassWrapper<TestSelectiveReplacementPatternDriver,
1527c8fb6ee3SRiver Riddle OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon71cd160d3411::TestSelectiveReplacementPatternDriver15285e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
15295e50dd04SRiver Riddle TestSelectiveReplacementPatternDriver)
15305e50dd04SRiver Riddle
1531b5e22e6dSMehdi Amini StringRef getArgument() const final {
1532b5e22e6dSMehdi Amini return "test-pattern-selective-replacement";
1533b5e22e6dSMehdi Amini }
getDescription__anon71cd160d3411::TestSelectiveReplacementPatternDriver1534b5e22e6dSMehdi Amini StringRef getDescription() const final {
1535b5e22e6dSMehdi Amini return "Test selective replacement in the PatternRewriter";
1536b5e22e6dSMehdi Amini }
runOnOperation__anon71cd160d3411::TestSelectiveReplacementPatternDriver1537c8fb6ee3SRiver Riddle void runOnOperation() override {
1538c8fb6ee3SRiver Riddle MLIRContext *context = &getContext();
1539dc4e913bSChris Lattner mlir::RewritePatternSet patterns(context);
1540dc4e913bSChris Lattner patterns.add<TestSelectiveOpReplacementPattern>(context);
1541e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1542c8fb6ee3SRiver Riddle std::move(patterns));
1543c8fb6ee3SRiver Riddle }
1544c8fb6ee3SRiver Riddle };
1545c8fb6ee3SRiver Riddle } // namespace
1546c8fb6ee3SRiver Riddle
1547c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===//
15484589dd92SRiver Riddle // PassRegistration
15494589dd92SRiver Riddle //===----------------------------------------------------------------------===//
15504589dd92SRiver Riddle
1551fec6c5acSUday Bondhugula namespace mlir {
155272c65b69SAlexander Belyaev namespace test {
registerPatternsTestPass()1553fec6c5acSUday Bondhugula void registerPatternsTestPass() {
1554b5e22e6dSMehdi Amini PassRegistration<TestReturnTypeDriver>();
1555fec6c5acSUday Bondhugula
1556b5e22e6dSMehdi Amini PassRegistration<TestDerivedAttributeDriver>();
15579ba37b3bSJacques Pienaar
1558b5e22e6dSMehdi Amini PassRegistration<TestPatternDriver>();
1559*ba3a9f51SChia-hung Duan PassRegistration<TestStrictPatternDriver>();
1560fec6c5acSUday Bondhugula
1561b5e22e6dSMehdi Amini PassRegistration<TestLegalizePatternDriver>([] {
1562b5e22e6dSMehdi Amini return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
1563fec6c5acSUday Bondhugula });
1564fec6c5acSUday Bondhugula
1565b5e22e6dSMehdi Amini PassRegistration<TestRemappedValue>();
156680d7ac3bSRiver Riddle
1567b5e22e6dSMehdi Amini PassRegistration<TestUnknownRootOpDriver>();
15684589dd92SRiver Riddle
1569b5e22e6dSMehdi Amini PassRegistration<TestTypeConversionDriver>();
1570d4a53f3bSAlex Zinenko PassRegistration<TestTargetMaterializationWithNoUses>();
1571e888886cSMaheshRavishankar
157288bc24a7SMathieu Fehr PassRegistration<TestRewriteDynamicOpDriver>();
157388bc24a7SMathieu Fehr
1574b5e22e6dSMehdi Amini PassRegistration<TestMergeBlocksPatternDriver>();
1575b5e22e6dSMehdi Amini PassRegistration<TestSelectiveReplacementPatternDriver>();
1576fec6c5acSUday Bondhugula }
157772c65b69SAlexander Belyaev } // namespace test
1578fec6c5acSUday Bondhugula } // namespace mlir
1579