1dc43f785SMehdi Amini //===- Pass.cpp - Pass Management -----------------------------------------===//
2dc43f785SMehdi Amini //
3dc43f785SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4dc43f785SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
5dc43f785SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6dc43f785SMehdi Amini //
7dc43f785SMehdi Amini //===----------------------------------------------------------------------===//
8dc43f785SMehdi Amini 
9dc43f785SMehdi Amini #include "Pass.h"
10dc43f785SMehdi Amini 
11436c6c9cSStella Laurenzo #include "IRModule.h"
125fef6ce0SStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
13dc43f785SMehdi Amini #include "mlir-c/Pass.h"
14dc43f785SMehdi Amini 
15dc43f785SMehdi Amini namespace py = pybind11;
16dc43f785SMehdi Amini using namespace mlir;
17dc43f785SMehdi Amini using namespace mlir::python;
18dc43f785SMehdi Amini 
19dc43f785SMehdi Amini namespace {
20dc43f785SMehdi Amini 
21dc43f785SMehdi Amini /// Owning Wrapper around a PassManager.
22dc43f785SMehdi Amini class PyPassManager {
23dc43f785SMehdi Amini public:
PyPassManager(MlirPassManager passManager)24dc43f785SMehdi Amini   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
PyPassManager(PyPassManager && other)255fef6ce0SStella Laurenzo   PyPassManager(PyPassManager &&other) : passManager(other.passManager) {
265fef6ce0SStella Laurenzo     other.passManager.ptr = nullptr;
275fef6ce0SStella Laurenzo   }
~PyPassManager()285fef6ce0SStella Laurenzo   ~PyPassManager() {
295fef6ce0SStella Laurenzo     if (!mlirPassManagerIsNull(passManager))
305fef6ce0SStella Laurenzo       mlirPassManagerDestroy(passManager);
315fef6ce0SStella Laurenzo   }
get()32dc43f785SMehdi Amini   MlirPassManager get() { return passManager; }
33dc43f785SMehdi Amini 
release()345fef6ce0SStella Laurenzo   void release() { passManager.ptr = nullptr; }
getCapsule()355fef6ce0SStella Laurenzo   pybind11::object getCapsule() {
365fef6ce0SStella Laurenzo     return py::reinterpret_steal<py::object>(
375fef6ce0SStella Laurenzo         mlirPythonPassManagerToCapsule(get()));
385fef6ce0SStella Laurenzo   }
395fef6ce0SStella Laurenzo 
createFromCapsule(pybind11::object capsule)405fef6ce0SStella Laurenzo   static pybind11::object createFromCapsule(pybind11::object capsule) {
415fef6ce0SStella Laurenzo     MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
425fef6ce0SStella Laurenzo     if (mlirPassManagerIsNull(rawPm))
435fef6ce0SStella Laurenzo       throw py::error_already_set();
445fef6ce0SStella Laurenzo     return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
455fef6ce0SStella Laurenzo   }
465fef6ce0SStella Laurenzo 
47dc43f785SMehdi Amini private:
48dc43f785SMehdi Amini   MlirPassManager passManager;
49dc43f785SMehdi Amini };
50dc43f785SMehdi Amini 
51be0a7e9fSMehdi Amini } // namespace
52dc43f785SMehdi Amini 
53dc43f785SMehdi Amini /// Create the `mlir.passmanager` here.
populatePassManagerSubmodule(py::module & m)54dc43f785SMehdi Amini void mlir::python::populatePassManagerSubmodule(py::module &m) {
55dc43f785SMehdi Amini   //----------------------------------------------------------------------------
56dc43f785SMehdi Amini   // Mapping of the top-level PassManager
57dc43f785SMehdi Amini   //----------------------------------------------------------------------------
58f05ff4f7SStella Laurenzo   py::class_<PyPassManager>(m, "PassManager", py::module_local())
59dc43f785SMehdi Amini       .def(py::init<>([](DefaultingPyMlirContext context) {
60dc43f785SMehdi Amini              MlirPassManager passManager =
61dc43f785SMehdi Amini                  mlirPassManagerCreate(context->get());
62dc43f785SMehdi Amini              return new PyPassManager(passManager);
63dc43f785SMehdi Amini            }),
64dc43f785SMehdi Amini            py::arg("context") = py::none(),
65dc43f785SMehdi Amini            "Create a new PassManager for the current (or provided) Context.")
665fef6ce0SStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
675fef6ce0SStella Laurenzo                              &PyPassManager::getCapsule)
685fef6ce0SStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
695fef6ce0SStella Laurenzo       .def("_testing_release", &PyPassManager::release,
705fef6ce0SStella Laurenzo            "Releases (leaks) the backing pass manager (testing)")
71caa159f0SNicolas Vasilache       .def(
72caa159f0SNicolas Vasilache           "enable_ir_printing",
73caa159f0SNicolas Vasilache           [](PyPassManager &passManager) {
74caa159f0SNicolas Vasilache             mlirPassManagerEnableIRPrinting(passManager.get());
75caa159f0SNicolas Vasilache           },
76*fb16ed25SAndrzej Warzynski           "Enable mlir-print-ir-after-all.")
77caa159f0SNicolas Vasilache       .def(
78caa159f0SNicolas Vasilache           "enable_verifier",
79caa159f0SNicolas Vasilache           [](PyPassManager &passManager, bool enable) {
80caa159f0SNicolas Vasilache             mlirPassManagerEnableVerifier(passManager.get(), enable);
81caa159f0SNicolas Vasilache           },
82a6e7d024SStella Laurenzo           py::arg("enable"), "Enable / disable verify-each.")
83dc43f785SMehdi Amini       .def_static(
84dc43f785SMehdi Amini           "parse",
85dc43f785SMehdi Amini           [](const std::string pipeline, DefaultingPyMlirContext context) {
86dc43f785SMehdi Amini             MlirPassManager passManager = mlirPassManagerCreate(context->get());
87dc43f785SMehdi Amini             MlirLogicalResult status = mlirParsePassPipeline(
88dc43f785SMehdi Amini                 mlirPassManagerGetAsOpPassManager(passManager),
89dc43f785SMehdi Amini                 mlirStringRefCreate(pipeline.data(), pipeline.size()));
90dc43f785SMehdi Amini             if (mlirLogicalResultIsFailure(status))
91dc43f785SMehdi Amini               throw SetPyError(PyExc_ValueError,
92dc43f785SMehdi Amini                                llvm::Twine("invalid pass pipeline '") +
93dc43f785SMehdi Amini                                    pipeline + "'.");
94dc43f785SMehdi Amini             return new PyPassManager(passManager);
95dc43f785SMehdi Amini           },
96dc43f785SMehdi Amini           py::arg("pipeline"), py::arg("context") = py::none(),
97dc43f785SMehdi Amini           "Parse a textual pass-pipeline and return a top-level PassManager "
98dc43f785SMehdi Amini           "that can be applied on a Module. Throw a ValueError if the pipeline "
99dc43f785SMehdi Amini           "can't be parsed")
100dc43f785SMehdi Amini       .def(
1016cb1c0caSMehdi Amini           "run",
1026cb1c0caSMehdi Amini           [](PyPassManager &passManager, PyModule &module) {
1036cb1c0caSMehdi Amini             MlirLogicalResult status =
1046cb1c0caSMehdi Amini                 mlirPassManagerRun(passManager.get(), module.get());
1056cb1c0caSMehdi Amini             if (mlirLogicalResultIsFailure(status))
1066cb1c0caSMehdi Amini               throw SetPyError(PyExc_RuntimeError,
1076cb1c0caSMehdi Amini                                "Failure while executing pass pipeline.");
1086cb1c0caSMehdi Amini           },
109a6e7d024SStella Laurenzo           py::arg("module"),
1106cb1c0caSMehdi Amini           "Run the pass manager on the provided module, throw a RuntimeError "
1116cb1c0caSMehdi Amini           "on failure.")
1126cb1c0caSMehdi Amini       .def(
113dc43f785SMehdi Amini           "__str__",
114dc43f785SMehdi Amini           [](PyPassManager &self) {
115dc43f785SMehdi Amini             MlirPassManager passManager = self.get();
116dc43f785SMehdi Amini             PyPrintAccumulator printAccum;
117dc43f785SMehdi Amini             mlirPrintPassPipeline(
118dc43f785SMehdi Amini                 mlirPassManagerGetAsOpPassManager(passManager),
119dc43f785SMehdi Amini                 printAccum.getCallback(), printAccum.getUserData());
120dc43f785SMehdi Amini             return printAccum.join();
121dc43f785SMehdi Amini           },
122dc43f785SMehdi Amini           "Print the textual representation for this PassManager, suitable to "
123dc43f785SMehdi Amini           "be passed to `parse` for round-tripping.");
124dc43f785SMehdi Amini }
125