1 //===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test the dynamic pipeline feature.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Pass/PassManager.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 class TestDynamicPipelinePass
21     : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
22 public:
23   StringRef getArgument() const final { return "test-dynamic-pipeline"; }
24   StringRef getDescription() const final {
25     return "Tests the dynamic pipeline feature by applying "
26            "a pipeline on a selected set of functions";
27   }
28   void getDependentDialects(DialectRegistry &registry) const override {
29     OpPassManager pm(ModuleOp::getOperationName(),
30                      OpPassManager::Nesting::Implicit);
31     (void)parsePassPipeline(pipeline, pm, llvm::errs());
32     pm.getDependentDialects(registry);
33   }
34 
35   TestDynamicPipelinePass(){};
36   TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
37 
38   void runOnOperation() override {
39     Operation *currentOp = getOperation();
40 
41     llvm::errs() << "Dynamic execute '" << pipeline << "' on "
42                  << currentOp->getName() << "\n";
43     if (pipeline.empty()) {
44       llvm::errs() << "Empty pipeline\n";
45       return;
46     }
47     auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp);
48     if (!symbolOp) {
49       currentOp->emitWarning()
50           << "Ignoring because not implementing SymbolOpInterface\n";
51       return;
52     }
53 
54     auto opName = symbolOp.getName();
55     if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
56       llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
57       return;
58     }
59     OpPassManager pm(currentOp->getName().getIdentifier(),
60                      OpPassManager::Nesting::Implicit);
61     (void)parsePassPipeline(pipeline, pm, llvm::errs());
62 
63     // Check that running on the parent operation always immediately fails.
64     if (runOnParent) {
65       if (currentOp->getParentOp())
66         if (!failed(runPipeline(pm, currentOp->getParentOp())))
67           signalPassFailure();
68       return;
69     }
70 
71     if (runOnNestedOp) {
72       llvm::errs() << "Run on nested op\n";
73       currentOp->walk([&](Operation *op) {
74         if (op == currentOp || !op->hasTrait<OpTrait::IsIsolatedFromAbove>() ||
75             op->getName() != currentOp->getName())
76           return;
77         llvm::errs() << "Run on " << *op << "\n";
78         // Run on the current operation
79         if (failed(runPipeline(pm, op)))
80           signalPassFailure();
81       });
82     } else {
83       // Run on the current operation
84       if (failed(runPipeline(pm, currentOp)))
85         signalPassFailure();
86     }
87   }
88 
89   Option<bool> runOnNestedOp{
90       *this, "run-on-nested-operations",
91       llvm::cl::desc("This will apply the pipeline on nested operations under "
92                      "the visited operation.")};
93   Option<bool> runOnParent{
94       *this, "run-on-parent",
95       llvm::cl::desc("This will apply the pipeline on the parent operation if "
96                      "it exist, this is expected to fail.")};
97   Option<std::string> pipeline{
98       *this, "dynamic-pipeline",
99       llvm::cl::desc("The pipeline description that "
100                      "will run on the filtered function.")};
101   ListOption<std::string> opNames{
102       *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
103       llvm::cl::desc("List of function name to apply the pipeline to")};
104 };
105 } // namespace
106 
107 namespace mlir {
108 namespace test {
109 void registerTestDynamicPipelinePass() {
110   PassRegistration<TestDynamicPipelinePass>();
111 }
112 } // namespace test
113 } // namespace mlir
114