1dc43f785SMehdi Amini //===- Pass.cpp - Pass Management -----------------------------------------===//
2dc43f785SMehdi Amini //
3dc43f785SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4dc43f785SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
5dc43f785SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6dc43f785SMehdi Amini //
7dc43f785SMehdi Amini //===----------------------------------------------------------------------===//
8dc43f785SMehdi Amini
9dc43f785SMehdi Amini #include "Pass.h"
10dc43f785SMehdi Amini
11436c6c9cSStella Laurenzo #include "IRModule.h"
125fef6ce0SStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
13dc43f785SMehdi Amini #include "mlir-c/Pass.h"
14dc43f785SMehdi Amini
15dc43f785SMehdi Amini namespace py = pybind11;
16dc43f785SMehdi Amini using namespace mlir;
17dc43f785SMehdi Amini using namespace mlir::python;
18dc43f785SMehdi Amini
19dc43f785SMehdi Amini namespace {
20dc43f785SMehdi Amini
21dc43f785SMehdi Amini /// Owning Wrapper around a PassManager.
22dc43f785SMehdi Amini class PyPassManager {
23dc43f785SMehdi Amini public:
PyPassManager(MlirPassManager passManager)24dc43f785SMehdi Amini PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
PyPassManager(PyPassManager && other)255fef6ce0SStella Laurenzo PyPassManager(PyPassManager &&other) : passManager(other.passManager) {
265fef6ce0SStella Laurenzo other.passManager.ptr = nullptr;
275fef6ce0SStella Laurenzo }
~PyPassManager()285fef6ce0SStella Laurenzo ~PyPassManager() {
295fef6ce0SStella Laurenzo if (!mlirPassManagerIsNull(passManager))
305fef6ce0SStella Laurenzo mlirPassManagerDestroy(passManager);
315fef6ce0SStella Laurenzo }
get()32dc43f785SMehdi Amini MlirPassManager get() { return passManager; }
33dc43f785SMehdi Amini
release()345fef6ce0SStella Laurenzo void release() { passManager.ptr = nullptr; }
getCapsule()355fef6ce0SStella Laurenzo pybind11::object getCapsule() {
365fef6ce0SStella Laurenzo return py::reinterpret_steal<py::object>(
375fef6ce0SStella Laurenzo mlirPythonPassManagerToCapsule(get()));
385fef6ce0SStella Laurenzo }
395fef6ce0SStella Laurenzo
createFromCapsule(pybind11::object capsule)405fef6ce0SStella Laurenzo static pybind11::object createFromCapsule(pybind11::object capsule) {
415fef6ce0SStella Laurenzo MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
425fef6ce0SStella Laurenzo if (mlirPassManagerIsNull(rawPm))
435fef6ce0SStella Laurenzo throw py::error_already_set();
445fef6ce0SStella Laurenzo return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
455fef6ce0SStella Laurenzo }
465fef6ce0SStella Laurenzo
47dc43f785SMehdi Amini private:
48dc43f785SMehdi Amini MlirPassManager passManager;
49dc43f785SMehdi Amini };
50dc43f785SMehdi Amini
51be0a7e9fSMehdi Amini } // namespace
52dc43f785SMehdi Amini
53dc43f785SMehdi Amini /// Create the `mlir.passmanager` here.
populatePassManagerSubmodule(py::module & m)54dc43f785SMehdi Amini void mlir::python::populatePassManagerSubmodule(py::module &m) {
55dc43f785SMehdi Amini //----------------------------------------------------------------------------
56dc43f785SMehdi Amini // Mapping of the top-level PassManager
57dc43f785SMehdi Amini //----------------------------------------------------------------------------
58f05ff4f7SStella Laurenzo py::class_<PyPassManager>(m, "PassManager", py::module_local())
59dc43f785SMehdi Amini .def(py::init<>([](DefaultingPyMlirContext context) {
60dc43f785SMehdi Amini MlirPassManager passManager =
61dc43f785SMehdi Amini mlirPassManagerCreate(context->get());
62dc43f785SMehdi Amini return new PyPassManager(passManager);
63dc43f785SMehdi Amini }),
64dc43f785SMehdi Amini py::arg("context") = py::none(),
65dc43f785SMehdi Amini "Create a new PassManager for the current (or provided) Context.")
665fef6ce0SStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
675fef6ce0SStella Laurenzo &PyPassManager::getCapsule)
685fef6ce0SStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
695fef6ce0SStella Laurenzo .def("_testing_release", &PyPassManager::release,
705fef6ce0SStella Laurenzo "Releases (leaks) the backing pass manager (testing)")
71caa159f0SNicolas Vasilache .def(
72caa159f0SNicolas Vasilache "enable_ir_printing",
73caa159f0SNicolas Vasilache [](PyPassManager &passManager) {
74caa159f0SNicolas Vasilache mlirPassManagerEnableIRPrinting(passManager.get());
75caa159f0SNicolas Vasilache },
76*fb16ed25SAndrzej Warzynski "Enable mlir-print-ir-after-all.")
77caa159f0SNicolas Vasilache .def(
78caa159f0SNicolas Vasilache "enable_verifier",
79caa159f0SNicolas Vasilache [](PyPassManager &passManager, bool enable) {
80caa159f0SNicolas Vasilache mlirPassManagerEnableVerifier(passManager.get(), enable);
81caa159f0SNicolas Vasilache },
82a6e7d024SStella Laurenzo py::arg("enable"), "Enable / disable verify-each.")
83dc43f785SMehdi Amini .def_static(
84dc43f785SMehdi Amini "parse",
85dc43f785SMehdi Amini [](const std::string pipeline, DefaultingPyMlirContext context) {
86dc43f785SMehdi Amini MlirPassManager passManager = mlirPassManagerCreate(context->get());
87dc43f785SMehdi Amini MlirLogicalResult status = mlirParsePassPipeline(
88dc43f785SMehdi Amini mlirPassManagerGetAsOpPassManager(passManager),
89dc43f785SMehdi Amini mlirStringRefCreate(pipeline.data(), pipeline.size()));
90dc43f785SMehdi Amini if (mlirLogicalResultIsFailure(status))
91dc43f785SMehdi Amini throw SetPyError(PyExc_ValueError,
92dc43f785SMehdi Amini llvm::Twine("invalid pass pipeline '") +
93dc43f785SMehdi Amini pipeline + "'.");
94dc43f785SMehdi Amini return new PyPassManager(passManager);
95dc43f785SMehdi Amini },
96dc43f785SMehdi Amini py::arg("pipeline"), py::arg("context") = py::none(),
97dc43f785SMehdi Amini "Parse a textual pass-pipeline and return a top-level PassManager "
98dc43f785SMehdi Amini "that can be applied on a Module. Throw a ValueError if the pipeline "
99dc43f785SMehdi Amini "can't be parsed")
100dc43f785SMehdi Amini .def(
1016cb1c0caSMehdi Amini "run",
1026cb1c0caSMehdi Amini [](PyPassManager &passManager, PyModule &module) {
1036cb1c0caSMehdi Amini MlirLogicalResult status =
1046cb1c0caSMehdi Amini mlirPassManagerRun(passManager.get(), module.get());
1056cb1c0caSMehdi Amini if (mlirLogicalResultIsFailure(status))
1066cb1c0caSMehdi Amini throw SetPyError(PyExc_RuntimeError,
1076cb1c0caSMehdi Amini "Failure while executing pass pipeline.");
1086cb1c0caSMehdi Amini },
109a6e7d024SStella Laurenzo py::arg("module"),
1106cb1c0caSMehdi Amini "Run the pass manager on the provided module, throw a RuntimeError "
1116cb1c0caSMehdi Amini "on failure.")
1126cb1c0caSMehdi Amini .def(
113dc43f785SMehdi Amini "__str__",
114dc43f785SMehdi Amini [](PyPassManager &self) {
115dc43f785SMehdi Amini MlirPassManager passManager = self.get();
116dc43f785SMehdi Amini PyPrintAccumulator printAccum;
117dc43f785SMehdi Amini mlirPrintPassPipeline(
118dc43f785SMehdi Amini mlirPassManagerGetAsOpPassManager(passManager),
119dc43f785SMehdi Amini printAccum.getCallback(), printAccum.getUserData());
120dc43f785SMehdi Amini return printAccum.join();
121dc43f785SMehdi Amini },
122dc43f785SMehdi Amini "Print the textual representation for this PassManager, suitable to "
123dc43f785SMehdi Amini "be passed to `parse` for round-tripping.");
124dc43f785SMehdi Amini }
125