1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Pass/PassManager.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/BuiltinOps.h" 12 #include "mlir/Pass/Pass.h" 13 #include "gtest/gtest.h" 14 15 #include <memory> 16 17 using namespace mlir; 18 using namespace mlir::detail; 19 20 namespace { 21 /// Analysis that operates on any operation. 22 struct GenericAnalysis { 23 GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {} 24 const bool isFunc; 25 }; 26 27 /// Analysis that operates on a specific operation. 28 struct OpSpecificAnalysis { 29 OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {} 30 const bool isSecret; 31 }; 32 33 /// Simple pass to annotate a FuncOp with the results of analysis. 34 /// Note: not using FunctionPass as it skip external functions. 35 struct AnnotateFunctionPass 36 : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> { 37 void runOnOperation() override { 38 FuncOp op = getOperation(); 39 Builder builder(op->getParentOfType<ModuleOp>()); 40 41 auto &ga = getAnalysis<GenericAnalysis>(); 42 auto &sa = getAnalysis<OpSpecificAnalysis>(); 43 44 op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc)); 45 op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret)); 46 } 47 }; 48 49 TEST(PassManagerTest, OpSpecificAnalysis) { 50 MLIRContext context; 51 Builder builder(&context); 52 53 // Create a module with 2 functions. 54 OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); 55 for (StringRef name : {"secret", "not_secret"}) { 56 FuncOp func = 57 FuncOp::create(builder.getUnknownLoc(), name, 58 builder.getFunctionType(llvm::None, llvm::None)); 59 func.setPrivate(); 60 module->push_back(func); 61 } 62 63 // Instantiate and run our pass. 64 PassManager pm(&context); 65 pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>()); 66 LogicalResult result = pm.run(module.get()); 67 EXPECT_TRUE(succeeded(result)); 68 69 // Verify that each function got annotated with expected attributes. 70 for (FuncOp func : module->getOps<FuncOp>()) { 71 ASSERT_TRUE(func->getAttr("isFunc").isa<BoolAttr>()); 72 EXPECT_TRUE(func->getAttr("isFunc").cast<BoolAttr>().getValue()); 73 74 bool isSecret = func.getName() == "secret"; 75 ASSERT_TRUE(func->getAttr("isSecret").isa<BoolAttr>()); 76 EXPECT_EQ(func->getAttr("isSecret").cast<BoolAttr>().getValue(), isSecret); 77 } 78 } 79 80 namespace { 81 struct InvalidPass : Pass { 82 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {} 83 StringRef getName() const override { return "Invalid Pass"; } 84 void runOnOperation() override {} 85 86 /// A clone method to create a copy of this pass. 87 std::unique_ptr<Pass> clonePass() const override { 88 return std::make_unique<InvalidPass>( 89 *static_cast<const InvalidPass *>(this)); 90 } 91 }; 92 } // namespace 93 94 TEST(PassManagerTest, InvalidPass) { 95 MLIRContext context; 96 context.allowUnregisteredDialects(); 97 98 // Create a module 99 OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); 100 101 // Add a single "invalid_op" operation 102 OpBuilder builder(&module->getBodyRegion()); 103 OperationState state(UnknownLoc::get(&context), "invalid_op"); 104 builder.insert(Operation::create(state)); 105 106 // Register a diagnostic handler to capture the diagnostic so that we can 107 // check it later. 108 std::unique_ptr<Diagnostic> diagnostic; 109 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 110 diagnostic = std::make_unique<Diagnostic>(std::move(diag)); 111 }); 112 113 // Instantiate and run our pass. 114 PassManager pm(&context); 115 pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>()); 116 LogicalResult result = pm.run(module.get()); 117 EXPECT_TRUE(failed(result)); 118 ASSERT_TRUE(diagnostic.get() != nullptr); 119 EXPECT_EQ( 120 diagnostic->str(), 121 "'invalid_op' op trying to schedule a pass on an unregistered operation"); 122 123 // Check that clearing the pass manager effectively removed the pass. 124 pm.clear(); 125 result = pm.run(module.get()); 126 EXPECT_TRUE(succeeded(result)); 127 128 // Check that adding the pass at the top-level triggers a fatal error. 129 ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), ""); 130 } 131 132 } // namespace 133