1 //===- TestPassManager.cpp - Test pass manager functionality --------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Function.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 struct TestModulePass : public ModulePass<TestModulePass> {
17   void runOnModule() final {}
18 };
19 struct TestFunctionPass : public FunctionPass<TestFunctionPass> {
20   void runOnFunction() final {}
21 };
22 class TestOptionsPass : public FunctionPass<TestOptionsPass> {
23 public:
24   struct Options : public PassPipelineOptions<Options> {
25     ListOption<int> listOption{*this, "list",
26                                llvm::cl::MiscFlags::CommaSeparated,
27                                llvm::cl::desc("Example list option")};
28     ListOption<std::string> stringListOption{
29         *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
30         llvm::cl::desc("Example string list option")};
31     Option<std::string> stringOption{*this, "string",
32                                      llvm::cl::desc("Example string option")};
33   };
34   TestOptionsPass() = default;
35   TestOptionsPass(const TestOptionsPass &) {}
36   TestOptionsPass(const Options &options) {
37     listOption->assign(options.listOption.begin(), options.listOption.end());
38     stringOption.setValue(options.stringOption);
39     stringListOption->assign(options.stringListOption.begin(),
40                              options.stringListOption.end());
41   }
42 
43   void runOnFunction() final {}
44 
45   ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
46                              llvm::cl::desc("Example list option")};
47   ListOption<std::string> stringListOption{
48       *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
49       llvm::cl::desc("Example string list option")};
50   Option<std::string> stringOption{*this, "string",
51                                    llvm::cl::desc("Example string option")};
52 };
53 
54 /// A test pass that always aborts to enable testing the crash recovery
55 /// mechanism of the pass manager.
56 class TestCrashRecoveryPass : public OperationPass<TestCrashRecoveryPass> {
57   void runOnOperation() final { abort(); }
58 };
59 
60 /// A test pass that contains a statistic.
61 struct TestStatisticPass : public OperationPass<TestStatisticPass> {
62   TestStatisticPass() = default;
63   TestStatisticPass(const TestStatisticPass &) {}
64 
65   Statistic opCount{this, "num-ops", "Number of operations counted"};
66 
67   void runOnOperation() final {
68     getOperation()->walk([&](Operation *) { ++opCount; });
69   }
70 };
71 } // end anonymous namespace
72 
73 static void testNestedPipeline(OpPassManager &pm) {
74   // Nest a module pipeline that contains:
75   /// A module pass.
76   auto &modulePM = pm.nest<ModuleOp>();
77   modulePM.addPass(std::make_unique<TestModulePass>());
78   /// A nested function pass.
79   auto &nestedFunctionPM = modulePM.nest<FuncOp>();
80   nestedFunctionPM.addPass(std::make_unique<TestFunctionPass>());
81 
82   // Nest a function pipeline that contains a single pass.
83   auto &functionPM = pm.nest<FuncOp>();
84   functionPM.addPass(std::make_unique<TestFunctionPass>());
85 }
86 
87 static void testNestedPipelineTextual(OpPassManager &pm) {
88   (void)parsePassPipeline("test-pm-nested-pipeline", pm);
89 }
90 
/// Static registration object for TestOptionsPass, using the argument
/// "test-options-pass" and the description given below.
static PassRegistration<TestOptionsPass>
    reg("test-options-pass", "Test options parsing capabilities");
93 
/// Static registration object for TestModulePass, using the argument
/// "test-module-pass".
static PassRegistration<TestModulePass>
    unusedMP("test-module-pass", "Test a module pass in the pass manager");
/// Static registration object for TestFunctionPass, using the argument
/// "test-function-pass".
static PassRegistration<TestFunctionPass>
    unusedFP("test-function-pass", "Test a function pass in the pass manager");
98 
/// Static registration object for TestCrashRecoveryPass, using the argument
/// "test-pass-crash".
static PassRegistration<TestCrashRecoveryPass>
    unusedCrashP("test-pass-crash",
                 "Test a pass in the pass manager that always crashes");
102 
/// Static registration object for TestStatisticPass, using the argument
/// "test-stats-pass".
static PassRegistration<TestStatisticPass> unusedStatP("test-stats-pass",
                                                       "Test pass statistics");
105 
/// Static pipeline registration object: associates the name
/// "test-pm-nested-pipeline" with the builder function testNestedPipeline.
static PassPipelineRegistration<>
    unused("test-pm-nested-pipeline",
           "Test a nested pipeline in the pass manager", testNestedPipeline);
/// Static pipeline registration object: associates the name
/// "test-textual-pm-nested-pipeline" with the builder function
/// testNestedPipelineTextual, which builds its pipeline by parsing text.
static PassPipelineRegistration<>
    unusedTextual("test-textual-pm-nested-pipeline",
                  "Test a nested pipeline in the pass manager",
                  testNestedPipelineTextual);
113 static PassPipelineRegistration<>
114     unusedDump("test-dump-pipeline",
115                "Dumps the pipeline build so far for debugging purposes",
116                [](OpPassManager &pm) {
117                  pm.printAsTextualPipeline(llvm::errs());
118                  llvm::errs() << "\n";
119                });
120 
121 static PassPipelineRegistration<TestOptionsPass::Options>
122     registerOptionsPassPipeline(
123         "test-options-pass-pipeline",
124         "Parses options using pass pipeline registration",
125         [](OpPassManager &pm, const TestOptionsPass::Options &options) {
126           pm.addPass(std::make_unique<TestOptionsPass>(options));
127         });
128