1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/Pass/Pass.h"
14 #include "gtest/gtest.h"
15 
16 #include <memory>
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
21 namespace {
22 /// Analysis that operates on any operation.
23 struct GenericAnalysis {
24   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
25 
GenericAnalysis__anonde9178ad0111::GenericAnalysis26   GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(op)) {}
27   const bool isFunc;
28 };
29 
30 /// Analysis that operates on a specific operation.
31 struct OpSpecificAnalysis {
32   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
33 
OpSpecificAnalysis__anonde9178ad0111::OpSpecificAnalysis34   OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
35   const bool isSecret;
36 };
37 
38 /// Simple pass to annotate a func::FuncOp with the results of analysis.
39 struct AnnotateFunctionPass
40     : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonde9178ad0111::AnnotateFunctionPass41   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
42 
43   void runOnOperation() override {
44     func::FuncOp op = getOperation();
45     Builder builder(op->getParentOfType<ModuleOp>());
46 
47     auto &ga = getAnalysis<GenericAnalysis>();
48     auto &sa = getAnalysis<OpSpecificAnalysis>();
49 
50     op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
51     op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
52   }
53 };
54 
/// Checks that a pass nested on func::FuncOp can query both an analysis that
/// accepts any operation and one that is only constructible from func::FuncOp.
TEST(PassManagerTest, OpSpecificAnalysis) {
  MLIRContext context;
  context.loadDialect<func::FuncDialect>();
  Builder builder(&context);

  // Create a module with 2 functions.
  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
  for (StringRef name : {"secret", "not_secret"}) {
    auto func =
        func::FuncOp::create(builder.getUnknownLoc(), name,
                             builder.getFunctionType(llvm::None, llvm::None));
    func.setPrivate();
    module->push_back(func);
  }

  // Instantiate and run our pass.
  PassManager pm(&context);
  pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
  LogicalResult result = pm.run(module.get());
  EXPECT_TRUE(succeeded(result));

  // Verify that each function got annotated with expected attributes.
  for (func::FuncOp func : module->getOps<func::FuncOp>()) {
    // `getAttrOfType` yields a null attribute when the attribute is absent or
    // has a different type, so a missing annotation fails the assertion below
    // instead of calling `isa<>` on a null attribute.
    auto isFuncAttr = func->getAttrOfType<BoolAttr>("isFunc");
    ASSERT_TRUE(static_cast<bool>(isFuncAttr));
    EXPECT_TRUE(isFuncAttr.getValue());

    bool isSecret = func.getName() == "secret";
    auto isSecretAttr = func->getAttrOfType<BoolAttr>("isSecret");
    ASSERT_TRUE(static_cast<bool>(isSecretAttr));
    EXPECT_EQ(isSecretAttr.getValue(), isSecret);
  }
}
86 
87 namespace {
88 struct InvalidPass : Pass {
89   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
90 
InvalidPass__anonde9178ad0111::__anonde9178ad0211::InvalidPass91   InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
getName__anonde9178ad0111::__anonde9178ad0211::InvalidPass92   StringRef getName() const override { return "Invalid Pass"; }
runOnOperation__anonde9178ad0111::__anonde9178ad0211::InvalidPass93   void runOnOperation() override {}
canScheduleOn__anonde9178ad0111::__anonde9178ad0211::InvalidPass94   bool canScheduleOn(RegisteredOperationName opName) const override {
95     return true;
96   }
97 
98   /// A clone method to create a copy of this pass.
clonePass__anonde9178ad0111::__anonde9178ad0211::InvalidPass99   std::unique_ptr<Pass> clonePass() const override {
100     return std::make_unique<InvalidPass>(
101         *static_cast<const InvalidPass *>(this));
102   }
103 };
104 } // namespace
105 
/// Checks that running a pass nested on an unregistered operation fails with
/// the expected diagnostic, that `clear()` removes the pass again, and that
/// adding the pass at the top level dies.
TEST(PassManagerTest, InvalidPass) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // Build an empty module to host the operation under test.
  Location loc = UnknownLoc::get(&context);
  OwningOpRef<ModuleOp> module(ModuleOp::create(loc));

  // Insert one unregistered "invalid_op" operation into the module body.
  OpBuilder builder(&module->getBodyRegion());
  OperationState state(loc, "invalid_op");
  builder.insert(Operation::create(state));

  // Capture the emitted diagnostic so its message can be inspected below.
  std::unique_ptr<Diagnostic> captured;
  context.getDiagEngine().registerHandler([&captured](Diagnostic &diag) {
    captured = std::make_unique<Diagnostic>(std::move(diag));
  });

  // Nest the pass on "invalid_op" and expect the run to fail.
  PassManager pm(&context);
  pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
  EXPECT_TRUE(failed(pm.run(module.get())));
  ASSERT_TRUE(captured != nullptr);
  EXPECT_EQ(
      captured->str(),
      "'invalid_op' op trying to schedule a pass on an unregistered operation");

  // Once cleared, the pass manager no longer holds the pass and the run
  // succeeds.
  pm.clear();
  EXPECT_TRUE(succeeded(pm.run(module.get())));

  // Adding the pass directly at the top level must trigger a fatal error.
  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
}
143 
144 } // namespace
145