1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "Pass.h"
10
11 #include "IRModule.h"
12 #include "mlir-c/Bindings/Python/Interop.h"
13 #include "mlir-c/Pass.h"
14
15 namespace py = pybind11;
16 using namespace mlir;
17 using namespace mlir::python;
18
19 namespace {
20
21 /// Owning Wrapper around a PassManager.
22 class PyPassManager {
23 public:
PyPassManager(MlirPassManager passManager)24 PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
PyPassManager(PyPassManager && other)25 PyPassManager(PyPassManager &&other) : passManager(other.passManager) {
26 other.passManager.ptr = nullptr;
27 }
~PyPassManager()28 ~PyPassManager() {
29 if (!mlirPassManagerIsNull(passManager))
30 mlirPassManagerDestroy(passManager);
31 }
get()32 MlirPassManager get() { return passManager; }
33
release()34 void release() { passManager.ptr = nullptr; }
getCapsule()35 pybind11::object getCapsule() {
36 return py::reinterpret_steal<py::object>(
37 mlirPythonPassManagerToCapsule(get()));
38 }
39
createFromCapsule(pybind11::object capsule)40 static pybind11::object createFromCapsule(pybind11::object capsule) {
41 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
42 if (mlirPassManagerIsNull(rawPm))
43 throw py::error_already_set();
44 return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
45 }
46
47 private:
48 MlirPassManager passManager;
49 };
50
51 } // namespace
52
53 /// Create the `mlir.passmanager` here.
populatePassManagerSubmodule(py::module & m)54 void mlir::python::populatePassManagerSubmodule(py::module &m) {
55 //----------------------------------------------------------------------------
56 // Mapping of the top-level PassManager
57 //----------------------------------------------------------------------------
58 py::class_<PyPassManager>(m, "PassManager", py::module_local())
59 .def(py::init<>([](DefaultingPyMlirContext context) {
60 MlirPassManager passManager =
61 mlirPassManagerCreate(context->get());
62 return new PyPassManager(passManager);
63 }),
64 py::arg("context") = py::none(),
65 "Create a new PassManager for the current (or provided) Context.")
66 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
67 &PyPassManager::getCapsule)
68 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
69 .def("_testing_release", &PyPassManager::release,
70 "Releases (leaks) the backing pass manager (testing)")
71 .def(
72 "enable_ir_printing",
73 [](PyPassManager &passManager) {
74 mlirPassManagerEnableIRPrinting(passManager.get());
75 },
76 "Enable mlir-print-ir-after-all.")
77 .def(
78 "enable_verifier",
79 [](PyPassManager &passManager, bool enable) {
80 mlirPassManagerEnableVerifier(passManager.get(), enable);
81 },
82 py::arg("enable"), "Enable / disable verify-each.")
83 .def_static(
84 "parse",
85 [](const std::string pipeline, DefaultingPyMlirContext context) {
86 MlirPassManager passManager = mlirPassManagerCreate(context->get());
87 MlirLogicalResult status = mlirParsePassPipeline(
88 mlirPassManagerGetAsOpPassManager(passManager),
89 mlirStringRefCreate(pipeline.data(), pipeline.size()));
90 if (mlirLogicalResultIsFailure(status))
91 throw SetPyError(PyExc_ValueError,
92 llvm::Twine("invalid pass pipeline '") +
93 pipeline + "'.");
94 return new PyPassManager(passManager);
95 },
96 py::arg("pipeline"), py::arg("context") = py::none(),
97 "Parse a textual pass-pipeline and return a top-level PassManager "
98 "that can be applied on a Module. Throw a ValueError if the pipeline "
99 "can't be parsed")
100 .def(
101 "run",
102 [](PyPassManager &passManager, PyModule &module) {
103 MlirLogicalResult status =
104 mlirPassManagerRun(passManager.get(), module.get());
105 if (mlirLogicalResultIsFailure(status))
106 throw SetPyError(PyExc_RuntimeError,
107 "Failure while executing pass pipeline.");
108 },
109 py::arg("module"),
110 "Run the pass manager on the provided module, throw a RuntimeError "
111 "on failure.")
112 .def(
113 "__str__",
114 [](PyPassManager &self) {
115 MlirPassManager passManager = self.get();
116 PyPrintAccumulator printAccum;
117 mlirPrintPassPipeline(
118 mlirPassManagerGetAsOpPassManager(passManager),
119 printAccum.getCallback(), printAccum.getUserData());
120 return printAccum.join();
121 },
122 "Print the textual representation for this PassManager, suitable to "
123 "be passed to `parse` for round-tripping.");
124 }
125