1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/Pass/Pass.h"
14 #include "gtest/gtest.h"
15
16 #include <memory>
17
18 using namespace mlir;
19 using namespace mlir::detail;
20
21 namespace {
22 /// Analysis that operates on any operation.
23 struct GenericAnalysis {
24 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
25
GenericAnalysis__anonde9178ad0111::GenericAnalysis26 GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(op)) {}
27 const bool isFunc;
28 };
29
30 /// Analysis that operates on a specific operation.
31 struct OpSpecificAnalysis {
32 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
33
OpSpecificAnalysis__anonde9178ad0111::OpSpecificAnalysis34 OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
35 const bool isSecret;
36 };
37
38 /// Simple pass to annotate a func::FuncOp with the results of analysis.
39 struct AnnotateFunctionPass
40 : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonde9178ad0111::AnnotateFunctionPass41 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
42
43 void runOnOperation() override {
44 func::FuncOp op = getOperation();
45 Builder builder(op->getParentOfType<ModuleOp>());
46
47 auto &ga = getAnalysis<GenericAnalysis>();
48 auto &sa = getAnalysis<OpSpecificAnalysis>();
49
50 op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
51 op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
52 }
53 };
54
/// Runs AnnotateFunctionPass over a module containing a "secret" and a
/// "not_secret" function and checks the attributes it attaches.
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  context.loadDialect<func::FuncDialect>();
  Builder builder(&context);

  // Create a module with 2 functions.
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef name : {"secret", "not_secret"}) {
    auto func =
        func::FuncOp::create(builder.getUnknownLoc(), name,
                             builder.getFunctionType(llvm::None, llvm::None));
    func.setPrivate();
    module->push_back(func);
  }

  // Instantiate and run our pass. Everything below inspects attributes the
  // pass is expected to attach, so stop here if the run did not succeed.
  PassManager pm(&context);
  pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
  LogicalResult result = pm.run(module.get());
  ASSERT_TRUE(succeeded(result));

  // Verify that each function got annotated with expected attributes.
  // `getAttrOfType` yields a null attribute when the attribute is missing or
  // of a different kind, so a missing annotation is reported as a test
  // failure rather than tripping the null-attribute assertion in `isa<>`.
  unsigned numFuncs = 0;
  for (func::FuncOp func : module->getOps<func::FuncOp>()) {
    ++numFuncs;

    BoolAttr isFuncAttr = func->getAttrOfType<BoolAttr>("isFunc");
    ASSERT_TRUE(isFuncAttr);
    EXPECT_TRUE(isFuncAttr.getValue());

    bool isSecret = func.getName() == "secret";
    BoolAttr isSecretAttr = func->getAttrOfType<BoolAttr>("isSecret");
    ASSERT_TRUE(isSecretAttr);
    EXPECT_EQ(isSecretAttr.getValue(), isSecret);
  }
  // Guard against the loop above passing vacuously.
  EXPECT_EQ(numFuncs, 2u);
}
86
87 namespace {
88 struct InvalidPass : Pass {
89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
90
InvalidPass__anonde9178ad0111::__anonde9178ad0211::InvalidPass91 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anonde9178ad0111::__anonde9178ad0211::InvalidPass92 StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anonde9178ad0111::__anonde9178ad0211::InvalidPass93 void runOnOperation() override {}
canScheduleOn__anonde9178ad0111::__anonde9178ad0211::InvalidPass94 bool canScheduleOn(RegisteredOperationName opName) const override {
95 return true;
96 }
97
98 /// A clone method to create a copy of this pass.
clonePass__anonde9178ad0111::__anonde9178ad0211::InvalidPass99 std::unique_ptr<Pass> clonePass() const override {
100 return std::make_unique<InvalidPass>(
101 *static_cast<const InvalidPass *>(this));
102 }
103 };
104 } // namespace
105
/// Checks that scheduling a pass on an unregistered operation fails with a
/// diagnostic, that `clear()` removes the pass, and that adding the pass at
/// the top level aborts.
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // Create a module
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));

  // Add a single "invalid_op" operation
  OpBuilder builder(&module->getBodyRegion());
  OperationState state(UnknownLoc::get(&context), "invalid_op");
  builder.insert(Operation::create(state));

  // Register a diagnostic handler to capture the diagnostic so that we can
  // check it later. The scoped handler is declared after `diagnostic`, so it
  // is unregistered from the context before the captured reference dies.
  std::unique_ptr<Diagnostic> diagnostic;
  ScopedDiagnosticHandler diagHandler(&context, [&](Diagnostic &diag) {
    diagnostic = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Instantiate and run our pass.
  PassManager pm(&context);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(failed(result));
  ASSERT_TRUE(diagnostic != nullptr);
  EXPECT_EQ(
      diagnostic->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Check that clearing the pass manager effectively removed the pass.
  pm.clear();
  result = pm.run(module.get());
  EXPECT_TRUE(succeeded(result));

  // Check that adding the pass at the top-level triggers a fatal error.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
}
143
144 } // namespace
145