1 //===- TestPassManager.cpp - Test pass manager functionality --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 struct TestModulePass
17     : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
18   void runOnOperation() final {}
19   StringRef getArgument() const final { return "test-module-pass"; }
20   StringRef getDescription() const final {
21     return "Test a module pass in the pass manager";
22   }
23 };
24 struct TestFunctionPass
25     : public PassWrapper<TestFunctionPass, OperationPass<FuncOp>> {
26   void runOnOperation() final {}
27   StringRef getArgument() const final { return "test-function-pass"; }
28   StringRef getDescription() const final {
29     return "Test a function pass in the pass manager";
30   }
31 };
32 class TestOptionsPass
33     : public PassWrapper<TestOptionsPass, OperationPass<FuncOp>> {
34 public:
35   struct Options : public PassPipelineOptions<Options> {
36     ListOption<int> listOption{*this, "list",
37                                llvm::cl::MiscFlags::CommaSeparated,
38                                llvm::cl::desc("Example list option")};
39     ListOption<std::string> stringListOption{
40         *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
41         llvm::cl::desc("Example string list option")};
42     Option<std::string> stringOption{*this, "string",
43                                      llvm::cl::desc("Example string option")};
44   };
45   TestOptionsPass() = default;
46   TestOptionsPass(const TestOptionsPass &) {}
47   TestOptionsPass(const Options &options) {
48     listOption = options.listOption;
49     stringOption = options.stringOption;
50     stringListOption = options.stringListOption;
51   }
52 
53   void runOnOperation() final {}
54   StringRef getArgument() const final { return "test-options-pass"; }
55   StringRef getDescription() const final {
56     return "Test options parsing capabilities";
57   }
58 
59   ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
60                              llvm::cl::desc("Example list option")};
61   ListOption<std::string> stringListOption{
62       *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
63       llvm::cl::desc("Example string list option")};
64   Option<std::string> stringOption{*this, "string",
65                                    llvm::cl::desc("Example string option")};
66 };
67 
68 /// A test pass that always aborts to enable testing the crash recovery
69 /// mechanism of the pass manager.
70 class TestCrashRecoveryPass
71     : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
72   void runOnOperation() final { abort(); }
73   StringRef getArgument() const final { return "test-pass-crash"; }
74   StringRef getDescription() const final {
75     return "Test a pass in the pass manager that always crashes";
76   }
77 };
78 
79 /// A test pass that always fails to enable testing the failure recovery
80 /// mechanisms of the pass manager.
81 class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
82   void runOnOperation() final { signalPassFailure(); }
83   StringRef getArgument() const final { return "test-pass-failure"; }
84   StringRef getDescription() const final {
85     return "Test a pass in the pass manager that always fails";
86   }
87 };
88 
89 /// A test pass that contains a statistic.
90 struct TestStatisticPass
91     : public PassWrapper<TestStatisticPass, OperationPass<>> {
92   TestStatisticPass() = default;
93   TestStatisticPass(const TestStatisticPass &) {}
94   StringRef getArgument() const final { return "test-stats-pass"; }
95   StringRef getDescription() const final { return "Test pass statistics"; }
96 
97   Statistic opCount{this, "num-ops", "Number of operations counted"};
98 
99   void runOnOperation() final {
100     getOperation()->walk([&](Operation *) { ++opCount; });
101   }
102 };
103 } // namespace
104 
105 static void testNestedPipeline(OpPassManager &pm) {
106   // Nest a module pipeline that contains:
107   /// A module pass.
108   auto &modulePM = pm.nest<ModuleOp>();
109   modulePM.addPass(std::make_unique<TestModulePass>());
110   /// A nested function pass.
111   auto &nestedFunctionPM = modulePM.nest<FuncOp>();
112   nestedFunctionPM.addPass(std::make_unique<TestFunctionPass>());
113 
114   // Nest a function pipeline that contains a single pass.
115   auto &functionPM = pm.nest<FuncOp>();
116   functionPM.addPass(std::make_unique<TestFunctionPass>());
117 }
118 
119 static void testNestedPipelineTextual(OpPassManager &pm) {
120   (void)parsePassPipeline("test-pm-nested-pipeline", pm);
121 }
122 
123 namespace mlir {
124 void registerPassManagerTestPass() {
125   PassRegistration<TestOptionsPass>();
126 
127   PassRegistration<TestModulePass>();
128 
129   PassRegistration<TestFunctionPass>();
130 
131   PassRegistration<TestCrashRecoveryPass>();
132   PassRegistration<TestFailurePass>();
133 
134   PassRegistration<TestStatisticPass>();
135 
136   PassPipelineRegistration<>("test-pm-nested-pipeline",
137                              "Test a nested pipeline in the pass manager",
138                              testNestedPipeline);
139   PassPipelineRegistration<>("test-textual-pm-nested-pipeline",
140                              "Test a nested pipeline in the pass manager",
141                              testNestedPipelineTextual);
142   PassPipelineRegistration<>(
143       "test-dump-pipeline",
144       "Dumps the pipeline build so far for debugging purposes",
145       [](OpPassManager &pm) {
146         pm.printAsTextualPipeline(llvm::errs());
147         llvm::errs() << "\n";
148       });
149 
150   PassPipelineRegistration<TestOptionsPass::Options>
151       registerOptionsPassPipeline(
152           "test-options-pass-pipeline",
153           "Parses options using pass pipeline registration",
154           [](OpPassManager &pm, const TestOptionsPass::Options &options) {
155             pm.addPass(std::make_unique<TestOptionsPass>(options));
156           });
157 }
158 } // namespace mlir
159