1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/Pass/Pass.h"
13 #include "gtest/gtest.h"
14 
15 #include <memory>
16 
17 using namespace mlir;
18 using namespace mlir::detail;
19 
20 namespace {
21 /// Analysis that operates on any operation.
22 struct GenericAnalysis {
23   GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
24   const bool isFunc;
25 };
26 
27 /// Analysis that operates on a specific operation.
28 struct OpSpecificAnalysis {
29   OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
30   const bool isSecret;
31 };
32 
33 /// Simple pass to annotate a FuncOp with the results of analysis.
34 /// Note: not using FunctionPass as it skip external functions.
35 struct AnnotateFunctionPass
36     : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
37   void runOnOperation() override {
38     FuncOp op = getOperation();
39     Builder builder(op->getParentOfType<ModuleOp>());
40 
41     auto &ga = getAnalysis<GenericAnalysis>();
42     auto &sa = getAnalysis<OpSpecificAnalysis>();
43 
44     op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
45     op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
46   }
47 };
48 
49 TEST(PassManagerTest, OpSpecificAnalysis) {
50   MLIRContext context;
51   Builder builder(&context);
52 
53   // Create a module with 2 functions.
54   OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
55   for (StringRef name : {"secret", "not_secret"}) {
56     FuncOp func =
57         FuncOp::create(builder.getUnknownLoc(), name,
58                        builder.getFunctionType(llvm::None, llvm::None));
59     func.setPrivate();
60     module->push_back(func);
61   }
62 
63   // Instantiate and run our pass.
64   PassManager pm(&context);
65   pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
66   LogicalResult result = pm.run(module.get());
67   EXPECT_TRUE(succeeded(result));
68 
69   // Verify that each function got annotated with expected attributes.
70   for (FuncOp func : module->getOps<FuncOp>()) {
71     ASSERT_TRUE(func->getAttr("isFunc").isa<BoolAttr>());
72     EXPECT_TRUE(func->getAttr("isFunc").cast<BoolAttr>().getValue());
73 
74     bool isSecret = func.getName() == "secret";
75     ASSERT_TRUE(func->getAttr("isSecret").isa<BoolAttr>());
76     EXPECT_EQ(func->getAttr("isSecret").cast<BoolAttr>().getValue(), isSecret);
77   }
78 }
79 
80 namespace {
81 struct InvalidPass : Pass {
82   InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
83   StringRef getName() const override { return "Invalid Pass"; }
84   void runOnOperation() override {}
85 
86   /// A clone method to create a copy of this pass.
87   std::unique_ptr<Pass> clonePass() const override {
88     return std::make_unique<InvalidPass>(
89         *static_cast<const InvalidPass *>(this));
90   }
91 };
92 } // namespace
93 
94 TEST(PassManagerTest, InvalidPass) {
95   MLIRContext context;
96   context.allowUnregisteredDialects();
97 
98   // Create a module
99   OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
100 
101   // Add a single "invalid_op" operation
102   OpBuilder builder(&module->getBodyRegion());
103   OperationState state(UnknownLoc::get(&context), "invalid_op");
104   builder.insert(Operation::create(state));
105 
106   // Register a diagnostic handler to capture the diagnostic so that we can
107   // check it later.
108   std::unique_ptr<Diagnostic> diagnostic;
109   context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
110     diagnostic = std::make_unique<Diagnostic>(std::move(diag));
111   });
112 
113   // Instantiate and run our pass.
114   PassManager pm(&context);
115   pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
116   LogicalResult result = pm.run(module.get());
117   EXPECT_TRUE(failed(result));
118   ASSERT_TRUE(diagnostic.get() != nullptr);
119   EXPECT_EQ(
120       diagnostic->str(),
121       "'invalid_op' op trying to schedule a pass on an unregistered operation");
122 
123   // Check that clearing the pass manager effectively removed the pass.
124   pm.clear();
125   result = pm.run(module.get());
126   EXPECT_TRUE(succeeded(result));
127 
128   // Check that adding the pass at the top-level triggers a fatal error.
129   ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
130 }
131 
132 } // namespace
133