//===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Parser.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 #include "mlir/Transforms/Passes.h"
14 #include "gtest/gtest.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 struct DisabledPattern : public RewritePattern {
21   DisabledPattern(MLIRContext *context)
22       : RewritePattern("test.foo", /*benefit=*/0, context,
23                        /*generatedNamed=*/{}) {
24     setDebugName("DisabledPattern");
25   }
26 
27   LogicalResult matchAndRewrite(Operation *op,
28                                 PatternRewriter &rewriter) const override {
29     if (op->getNumResults() != 1)
30       return failure();
31     rewriter.eraseOp(op);
32     return success();
33   }
34 };
35 
36 struct EnabledPattern : public RewritePattern {
37   EnabledPattern(MLIRContext *context)
38       : RewritePattern("test.foo", /*benefit=*/0, context,
39                        /*generatedNamed=*/{}) {
40     setDebugName("EnabledPattern");
41   }
42 
43   LogicalResult matchAndRewrite(Operation *op,
44                                 PatternRewriter &rewriter) const override {
45     if (op->getNumResults() == 1)
46       return failure();
47     rewriter.eraseOp(op);
48     return success();
49   }
50 };
51 
52 struct TestDialect : public Dialect {
53   static StringRef getDialectNamespace() { return "test"; }
54 
55   TestDialect(MLIRContext *context)
56       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
57     allowUnknownOperations();
58   }
59 
60   void getCanonicalizationPatterns(RewritePatternSet &results) const override {
61     results.insert<DisabledPattern, EnabledPattern>(results.getContext());
62   }
63 };
64 
65 TEST(CanonicalizerTest, TestDisablePatterns) {
66   MLIRContext context;
67   context.getOrLoadDialect<TestDialect>();
68   PassManager mgr(&context);
69   mgr.addPass(
70       createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));
71 
72   const char *const code = R"mlir(
73     %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
74     %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
75   )mlir";
76 
77   OwningModuleRef module = mlir::parseSourceString(code, &context);
78   ASSERT_TRUE(succeeded(mgr.run(*module)));
79 
80   EXPECT_TRUE(module->lookupSymbol("B"));
81   EXPECT_FALSE(module->lookupSymbol("A"));
82 }
83 
84 } // end anonymous namespace
85