1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "Pass.h"

#include "IRModules.h"
#include "mlir-c/Pass.h"

#include <memory>
13 
14 namespace py = pybind11;
15 using namespace mlir;
16 using namespace mlir::python;
17 
18 namespace {
19 
20 /// Owning Wrapper around a PassManager.
21 class PyPassManager {
22 public:
23   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
24   ~PyPassManager() { mlirPassManagerDestroy(passManager); }
25   MlirPassManager get() { return passManager; }
26 
27 private:
28   MlirPassManager passManager;
29 };
30 
31 } // anonymous namespace
32 
33 /// Create the `mlir.passmanager` here.
34 void mlir::python::populatePassManagerSubmodule(py::module &m) {
35   //----------------------------------------------------------------------------
36   // Mapping of the top-level PassManager
37   //----------------------------------------------------------------------------
38   py::class_<PyPassManager>(m, "PassManager")
39       .def(py::init<>([](DefaultingPyMlirContext context) {
40              MlirPassManager passManager =
41                  mlirPassManagerCreate(context->get());
42              return new PyPassManager(passManager);
43            }),
44            py::arg("context") = py::none(),
45            "Create a new PassManager for the current (or provided) Context.")
46       .def_static(
47           "parse",
48           [](const std::string pipeline, DefaultingPyMlirContext context) {
49             MlirPassManager passManager = mlirPassManagerCreate(context->get());
50             MlirLogicalResult status = mlirParsePassPipeline(
51                 mlirPassManagerGetAsOpPassManager(passManager),
52                 mlirStringRefCreate(pipeline.data(), pipeline.size()));
53             if (mlirLogicalResultIsFailure(status))
54               throw SetPyError(PyExc_ValueError,
55                                llvm::Twine("invalid pass pipeline '") +
56                                    pipeline + "'.");
57             return new PyPassManager(passManager);
58           },
59           py::arg("pipeline"), py::arg("context") = py::none(),
60           "Parse a textual pass-pipeline and return a top-level PassManager "
61           "that can be applied on a Module. Throw a ValueError if the pipeline "
62           "can't be parsed")
63       .def(
64           "run",
65           [](PyPassManager &passManager, PyModule &module) {
66             MlirLogicalResult status =
67                 mlirPassManagerRun(passManager.get(), module.get());
68             if (mlirLogicalResultIsFailure(status))
69               throw SetPyError(PyExc_RuntimeError,
70                                "Failure while executing pass pipeline.");
71           },
72           "Run the pass manager on the provided module, throw a RuntimeError "
73           "on failure.")
74       .def(
75           "__str__",
76           [](PyPassManager &self) {
77             MlirPassManager passManager = self.get();
78             PyPrintAccumulator printAccum;
79             mlirPrintPassPipeline(
80                 mlirPassManagerGetAsOpPassManager(passManager),
81                 printAccum.getCallback(), printAccum.getUserData());
82             return printAccum.join();
83           },
84           "Print the textual representation for this PassManager, suitable to "
85           "be passed to `parse` for round-tripping.");
86 }
87