1 //===- TestPassManager.cpp - Test pass manager functionality --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

#include <cstdlib>
12 
13 using namespace mlir;
14 
15 namespace {
16 struct TestModulePass
17     : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
18   void runOnOperation() final {}
19   StringRef getArgument() const final { return "test-module-pass"; }
20   StringRef getDescription() const final {
21     return "Test a module pass in the pass manager";
22   }
23 };
24 struct TestFunctionPass
25     : public PassWrapper<TestFunctionPass, OperationPass<FuncOp>> {
26   void runOnOperation() final {}
27   StringRef getArgument() const final { return "test-function-pass"; }
28   StringRef getDescription() const final {
29     return "Test a function pass in the pass manager";
30   }
31 };
32 class TestInterfacePass
33     : public PassWrapper<TestInterfacePass,
34                          InterfacePass<FunctionOpInterface>> {
35   void runOnOperation() final {
36     getOperation()->emitRemark() << "Executing interface pass on operation";
37   }
38   StringRef getArgument() const final { return "test-interface-pass"; }
39   StringRef getDescription() const final {
40     return "Test an interface pass (running on FunctionOpInterface) in the "
41            "pass manager";
42   }
43 };
44 class TestOptionsPass
45     : public PassWrapper<TestOptionsPass, OperationPass<FuncOp>> {
46 public:
47   struct Options : public PassPipelineOptions<Options> {
48     ListOption<int> listOption{*this, "list",
49                                llvm::cl::MiscFlags::CommaSeparated,
50                                llvm::cl::desc("Example list option")};
51     ListOption<std::string> stringListOption{
52         *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
53         llvm::cl::desc("Example string list option")};
54     Option<std::string> stringOption{*this, "string",
55                                      llvm::cl::desc("Example string option")};
56   };
57   TestOptionsPass() = default;
58   TestOptionsPass(const TestOptionsPass &) {}
59   TestOptionsPass(const Options &options) {
60     listOption = options.listOption;
61     stringOption = options.stringOption;
62     stringListOption = options.stringListOption;
63   }
64 
65   void runOnOperation() final {}
66   StringRef getArgument() const final { return "test-options-pass"; }
67   StringRef getDescription() const final {
68     return "Test options parsing capabilities";
69   }
70 
71   ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
72                              llvm::cl::desc("Example list option")};
73   ListOption<std::string> stringListOption{
74       *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
75       llvm::cl::desc("Example string list option")};
76   Option<std::string> stringOption{*this, "string",
77                                    llvm::cl::desc("Example string option")};
78 };
79 
80 /// A test pass that always aborts to enable testing the crash recovery
81 /// mechanism of the pass manager.
82 class TestCrashRecoveryPass
83     : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
84   void runOnOperation() final { abort(); }
85   StringRef getArgument() const final { return "test-pass-crash"; }
86   StringRef getDescription() const final {
87     return "Test a pass in the pass manager that always crashes";
88   }
89 };
90 
91 /// A test pass that always fails to enable testing the failure recovery
92 /// mechanisms of the pass manager.
93 class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
94   void runOnOperation() final { signalPassFailure(); }
95   StringRef getArgument() const final { return "test-pass-failure"; }
96   StringRef getDescription() const final {
97     return "Test a pass in the pass manager that always fails";
98   }
99 };
100 
101 /// A test pass that contains a statistic.
102 struct TestStatisticPass
103     : public PassWrapper<TestStatisticPass, OperationPass<>> {
104   TestStatisticPass() = default;
105   TestStatisticPass(const TestStatisticPass &) {}
106   StringRef getArgument() const final { return "test-stats-pass"; }
107   StringRef getDescription() const final { return "Test pass statistics"; }
108 
109   Statistic opCount{this, "num-ops", "Number of operations counted"};
110 
111   void runOnOperation() final {
112     getOperation()->walk([&](Operation *) { ++opCount; });
113   }
114 };
115 } // namespace
116 
117 static void testNestedPipeline(OpPassManager &pm) {
118   // Nest a module pipeline that contains:
119   /// A module pass.
120   auto &modulePM = pm.nest<ModuleOp>();
121   modulePM.addPass(std::make_unique<TestModulePass>());
122   /// A nested function pass.
123   auto &nestedFunctionPM = modulePM.nest<FuncOp>();
124   nestedFunctionPM.addPass(std::make_unique<TestFunctionPass>());
125 
126   // Nest a function pipeline that contains a single pass.
127   auto &functionPM = pm.nest<FuncOp>();
128   functionPM.addPass(std::make_unique<TestFunctionPass>());
129 }
130 
131 static void testNestedPipelineTextual(OpPassManager &pm) {
132   (void)parsePassPipeline("test-pm-nested-pipeline", pm);
133 }
134 
135 namespace mlir {
136 void registerPassManagerTestPass() {
137   PassRegistration<TestOptionsPass>();
138 
139   PassRegistration<TestModulePass>();
140 
141   PassRegistration<TestFunctionPass>();
142 
143   PassRegistration<TestInterfacePass>();
144 
145   PassRegistration<TestCrashRecoveryPass>();
146   PassRegistration<TestFailurePass>();
147 
148   PassRegistration<TestStatisticPass>();
149 
150   PassPipelineRegistration<>("test-pm-nested-pipeline",
151                              "Test a nested pipeline in the pass manager",
152                              testNestedPipeline);
153   PassPipelineRegistration<>("test-textual-pm-nested-pipeline",
154                              "Test a nested pipeline in the pass manager",
155                              testNestedPipelineTextual);
156   PassPipelineRegistration<>(
157       "test-dump-pipeline",
158       "Dumps the pipeline build so far for debugging purposes",
159       [](OpPassManager &pm) {
160         pm.printAsTextualPipeline(llvm::errs());
161         llvm::errs() << "\n";
162       });
163 
164   PassPipelineRegistration<TestOptionsPass::Options>
165       registerOptionsPassPipeline(
166           "test-options-pass-pipeline",
167           "Parses options using pass pipeline registration",
168           [](OpPassManager &pm, const TestOptionsPass::Options &options) {
169             pm.addPass(std::make_unique<TestOptionsPass>(options));
170           });
171 }
172 } // namespace mlir
173