1 //===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test the dynamic pipeline feature.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Pass/PassManager.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 class TestDynamicPipelinePass
21     : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
22 public:
23   StringRef getArgument() const final { return "test-dynamic-pipeline"; }
24   StringRef getDescription() const final {
25     return "Tests the dynamic pipeline feature by applying "
26            "a pipeline on a selected set of functions";
27   }
28   void getDependentDialects(DialectRegistry &registry) const override {
29     OpPassManager pm(ModuleOp::getOperationName(),
30                      OpPassManager::Nesting::Implicit);
31     (void)parsePassPipeline(pipeline, pm, llvm::errs());
32     pm.getDependentDialects(registry);
33   }
34 
35   TestDynamicPipelinePass(){};
36   TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
37 
38   void runOnOperation() override {
39     Operation *currentOp = getOperation();
40 
41     llvm::errs() << "Dynamic execute '" << pipeline << "' on "
42                  << currentOp->getName() << "\n";
43     if (pipeline.empty()) {
44       llvm::errs() << "Empty pipeline\n";
45       return;
46     }
47     auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp);
48     if (!symbolOp) {
49       currentOp->emitWarning()
50           << "Ignoring because not implementing SymbolOpInterface\n";
51       return;
52     }
53 
54     auto opName = symbolOp.getName();
55     if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
56       llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
57       return;
58     }
59     if (!pm) {
60       pm = std::make_unique<OpPassManager>(currentOp->getName().getIdentifier(),
61                                            OpPassManager::Nesting::Implicit);
62       (void)parsePassPipeline(pipeline, *pm, llvm::errs());
63     }
64 
65     // Check that running on the parent operation always immediately fails.
66     if (runOnParent) {
67       if (currentOp->getParentOp())
68         if (!failed(runPipeline(*pm, currentOp->getParentOp())))
69           signalPassFailure();
70       return;
71     }
72 
73     if (runOnNestedOp) {
74       llvm::errs() << "Run on nested op\n";
75       currentOp->walk([&](Operation *op) {
76         if (op == currentOp || !op->hasTrait<OpTrait::IsIsolatedFromAbove>() ||
77             op->getName() != currentOp->getName())
78           return;
79         llvm::errs() << "Run on " << *op << "\n";
80         // Run on the current operation
81         if (failed(runPipeline(*pm, op)))
82           signalPassFailure();
83       });
84     } else {
85       // Run on the current operation
86       if (failed(runPipeline(*pm, currentOp)))
87         signalPassFailure();
88     }
89   }
90 
91   std::unique_ptr<OpPassManager> pm;
92 
93   Option<bool> runOnNestedOp{
94       *this, "run-on-nested-operations",
95       llvm::cl::desc("This will apply the pipeline on nested operations under "
96                      "the visited operation.")};
97   Option<bool> runOnParent{
98       *this, "run-on-parent",
99       llvm::cl::desc("This will apply the pipeline on the parent operation if "
100                      "it exist, this is expected to fail.")};
101   Option<std::string> pipeline{
102       *this, "dynamic-pipeline",
103       llvm::cl::desc("The pipeline description that "
104                      "will run on the filtered function.")};
105   ListOption<std::string> opNames{
106       *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
107       llvm::cl::desc("List of function name to apply the pipeline to")};
108 };
109 } // namespace
110 
111 namespace mlir {
112 namespace test {
113 void registerTestDynamicPipelinePass() {
114   PassRegistration<TestDynamicPipelinePass>();
115 }
116 } // namespace test
117 } // namespace mlir
118