1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/PatternMatch.h" 10 #include "mlir/Parser.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 13 #include "mlir/Transforms/Passes.h" 14 #include "gtest/gtest.h" 15 16 using namespace mlir; 17 18 namespace { 19 20 struct DisabledPattern : public RewritePattern { 21 DisabledPattern(MLIRContext *context) 22 : RewritePattern("test.foo", /*benefit=*/0, context, 23 /*generatedNamed=*/{}) { 24 setDebugName("DisabledPattern"); 25 } 26 27 LogicalResult matchAndRewrite(Operation *op, 28 PatternRewriter &rewriter) const override { 29 if (op->getNumResults() != 1) 30 return failure(); 31 rewriter.eraseOp(op); 32 return success(); 33 } 34 }; 35 36 struct EnabledPattern : public RewritePattern { 37 EnabledPattern(MLIRContext *context) 38 : RewritePattern("test.foo", /*benefit=*/0, context, 39 /*generatedNamed=*/{}) { 40 setDebugName("EnabledPattern"); 41 } 42 43 LogicalResult matchAndRewrite(Operation *op, 44 PatternRewriter &rewriter) const override { 45 if (op->getNumResults() == 1) 46 return failure(); 47 rewriter.eraseOp(op); 48 return success(); 49 } 50 }; 51 52 struct TestDialect : public Dialect { 53 static StringRef getDialectNamespace() { return "test"; } 54 55 TestDialect(MLIRContext *context) 56 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) { 57 allowUnknownOperations(); 58 } 59 60 void getCanonicalizationPatterns(RewritePatternSet &results) const override { 61 results.insert<DisabledPattern, EnabledPattern>(results.getContext()); 62 } 63 }; 64 65 TEST(CanonicalizerTest, TestDisablePatterns) { 66 MLIRContext context; 67 context.getOrLoadDialect<TestDialect>(); 68 PassManager mgr(&context); 69 mgr.addPass( 70 createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"})); 71 72 const char *const code = R"mlir( 73 %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32) 74 %1 = "test.foo"() {sym_name = "B"} : () -> (f32) 75 )mlir"; 76 77 OwningModuleRef module = mlir::parseSourceString(code, &context); 78 ASSERT_TRUE(succeeded(mgr.run(*module))); 79 80 EXPECT_TRUE(module->lookupSymbol("B")); 81 EXPECT_FALSE(module->lookupSymbol("A")); 82 } 83 84 } // end anonymous namespace 85