1 //===- Pass.cpp - Pass Management -----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Pass.h" 10 11 #include "IRModules.h" 12 #include "mlir-c/Pass.h" 13 14 namespace py = pybind11; 15 using namespace mlir; 16 using namespace mlir::python; 17 18 namespace { 19 20 /// Owning Wrapper around a PassManager. 21 class PyPassManager { 22 public: 23 PyPassManager(MlirPassManager passManager) : passManager(passManager) {} 24 ~PyPassManager() { mlirPassManagerDestroy(passManager); } 25 MlirPassManager get() { return passManager; } 26 27 private: 28 MlirPassManager passManager; 29 }; 30 31 } // anonymous namespace 32 33 /// Create the `mlir.passmanager` here. 34 void mlir::python::populatePassManagerSubmodule(py::module &m) { 35 //---------------------------------------------------------------------------- 36 // Mapping of the top-level PassManager 37 //---------------------------------------------------------------------------- 38 py::class_<PyPassManager>(m, "PassManager") 39 .def(py::init<>([](DefaultingPyMlirContext context) { 40 MlirPassManager passManager = 41 mlirPassManagerCreate(context->get()); 42 return new PyPassManager(passManager); 43 }), 44 py::arg("context") = py::none(), 45 "Create a new PassManager for the current (or provided) Context.") 46 .def_static( 47 "parse", 48 [](const std::string pipeline, DefaultingPyMlirContext context) { 49 MlirPassManager passManager = mlirPassManagerCreate(context->get()); 50 MlirLogicalResult status = mlirParsePassPipeline( 51 mlirPassManagerGetAsOpPassManager(passManager), 52 mlirStringRefCreate(pipeline.data(), pipeline.size())); 53 if (mlirLogicalResultIsFailure(status)) 54 throw SetPyError(PyExc_ValueError, 55 llvm::Twine("invalid pass pipeline '") + 56 pipeline + "'."); 57 return new PyPassManager(passManager); 58 }, 59 py::arg("pipeline"), py::arg("context") = py::none(), 60 "Parse a textual pass-pipeline and return a top-level PassManager " 61 "that can be applied on a Module. Throw a ValueError if the pipeline " 62 "can't be parsed") 63 .def( 64 "run", 65 [](PyPassManager &passManager, PyModule &module) { 66 MlirLogicalResult status = 67 mlirPassManagerRun(passManager.get(), module.get()); 68 if (mlirLogicalResultIsFailure(status)) 69 throw SetPyError(PyExc_RuntimeError, 70 "Failure while executing pass pipeline."); 71 }, 72 "Run the pass manager on the provided module, throw a RuntimeError " 73 "on failure.") 74 .def( 75 "__str__", 76 [](PyPassManager &self) { 77 MlirPassManager passManager = self.get(); 78 PyPrintAccumulator printAccum; 79 mlirPrintPassPipeline( 80 mlirPassManagerGetAsOpPassManager(passManager), 81 printAccum.getCallback(), printAccum.getUserData()); 82 return printAccum.join(); 83 }, 84 "Print the textual representation for this PassManager, suitable to " 85 "be passed to `parse` for round-tripping."); 86 } 87