1 //===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test the dynamic pipeline feature.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Pass/PassManager.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 class TestDynamicPipelinePass
21     : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
22 public:
23   void getDependentDialects(DialectRegistry &registry) const override {
24     OpPassManager pm(ModuleOp::getOperationName(),
25                      OpPassManager::Nesting::Implicit);
26     (void)parsePassPipeline(pipeline, pm, llvm::errs());
27     pm.getDependentDialects(registry);
28   }
29 
30   TestDynamicPipelinePass(){};
31   TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
32 
33   void runOnOperation() override {
34     Operation *currentOp = getOperation();
35 
36     llvm::errs() << "Dynamic execute '" << pipeline << "' on "
37                  << currentOp->getName() << "\n";
38     if (pipeline.empty()) {
39       llvm::errs() << "Empty pipeline\n";
40       return;
41     }
42     auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp);
43     if (!symbolOp) {
44       currentOp->emitWarning()
45           << "Ignoring because not implementing SymbolOpInterface\n";
46       return;
47     }
48 
49     auto opName = symbolOp.getName();
50     if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
51       llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
52       return;
53     }
54     if (!pm) {
55       pm = std::make_unique<OpPassManager>(currentOp->getName().getIdentifier(),
56                                            OpPassManager::Nesting::Implicit);
57       (void)parsePassPipeline(pipeline, *pm, llvm::errs());
58     }
59 
60     // Check that running on the parent operation always immediately fails.
61     if (runOnParent) {
62       if (currentOp->getParentOp())
63         if (!failed(runPipeline(*pm, currentOp->getParentOp())))
64           signalPassFailure();
65       return;
66     }
67 
68     if (runOnNestedOp) {
69       llvm::errs() << "Run on nested op\n";
70       currentOp->walk([&](Operation *op) {
71         if (op == currentOp || !op->hasTrait<OpTrait::IsIsolatedFromAbove>() ||
72             op->getName() != currentOp->getName())
73           return;
74         llvm::errs() << "Run on " << *op << "\n";
75         // Run on the current operation
76         if (failed(runPipeline(*pm, op)))
77           signalPassFailure();
78       });
79     } else {
80       // Run on the current operation
81       if (failed(runPipeline(*pm, currentOp)))
82         signalPassFailure();
83     }
84   }
85 
86   std::unique_ptr<OpPassManager> pm;
87 
88   Option<bool> runOnNestedOp{
89       *this, "run-on-nested-operations",
90       llvm::cl::desc("This will apply the pipeline on nested operations under "
91                      "the visited operation.")};
92   Option<bool> runOnParent{
93       *this, "run-on-parent",
94       llvm::cl::desc("This will apply the pipeline on the parent operation if "
95                      "it exist, this is expected to fail.")};
96   Option<std::string> pipeline{
97       *this, "dynamic-pipeline",
98       llvm::cl::desc("The pipeline description that "
99                      "will run on the filtered function.")};
100   ListOption<std::string> opNames{
101       *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
102       llvm::cl::desc("List of function name to apply the pipeline to")};
103 };
104 } // namespace
105 
106 namespace mlir {
107 namespace test {
108 void registerTestDynamicPipelinePass() {
109   PassRegistration<TestDynamicPipelinePass>(
110       "test-dynamic-pipeline", "Tests the dynamic pipeline feature by applying "
111                                "a pipeline on a selected set of functions");
112 }
113 } // namespace test
114 } // namespace mlir
115