1722475a3SStella Laurenzo //===- MainModule.cpp - Main pybind module --------------------------------===// 2722475a3SStella Laurenzo // 3722475a3SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4722475a3SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5722475a3SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6722475a3SStella Laurenzo // 7722475a3SStella Laurenzo //===----------------------------------------------------------------------===// 8722475a3SStella Laurenzo 9722475a3SStella Laurenzo #include <tuple> 10722475a3SStella Laurenzo 11*013b9322SStella Laurenzo #include "PybindUtils.h" 12722475a3SStella Laurenzo 13*013b9322SStella Laurenzo #include "Globals.h" 14fcd2969dSzhanghb97 #include "IRModules.h" 15722475a3SStella Laurenzo 1695b77f2eSStella Laurenzo namespace py = pybind11; 17722475a3SStella Laurenzo using namespace mlir; 1895b77f2eSStella Laurenzo using namespace mlir::python; 19722475a3SStella Laurenzo 20*013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 21*013b9322SStella Laurenzo // PyGlobals 22*013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 23*013b9322SStella Laurenzo 24*013b9322SStella Laurenzo PyGlobals *PyGlobals::instance = nullptr; 25*013b9322SStella Laurenzo 26*013b9322SStella Laurenzo PyGlobals::PyGlobals() { 27*013b9322SStella Laurenzo assert(!instance && "PyGlobals already constructed"); 28*013b9322SStella Laurenzo instance = this; 29*013b9322SStella Laurenzo } 30*013b9322SStella Laurenzo 31*013b9322SStella Laurenzo PyGlobals::~PyGlobals() { instance = nullptr; } 32*013b9322SStella Laurenzo 33*013b9322SStella Laurenzo void PyGlobals::loadDialectModule(const std::string &dialectNamespace) { 34*013b9322SStella Laurenzo if (loadedDialectModules.contains(dialectNamespace)) 35*013b9322SStella Laurenzo return; 36*013b9322SStella Laurenzo // Since re-entrancy is possible, make a copy of the search prefixes. 37*013b9322SStella Laurenzo std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 38*013b9322SStella Laurenzo py::object loaded; 39*013b9322SStella Laurenzo for (std::string moduleName : localSearchPrefixes) { 40*013b9322SStella Laurenzo moduleName.push_back('.'); 41*013b9322SStella Laurenzo moduleName.append(dialectNamespace); 42*013b9322SStella Laurenzo 43*013b9322SStella Laurenzo try { 44*013b9322SStella Laurenzo loaded = py::module::import(moduleName.c_str()); 45*013b9322SStella Laurenzo } catch (py::error_already_set &e) { 46*013b9322SStella Laurenzo if (e.matches(PyExc_ModuleNotFoundError)) { 47*013b9322SStella Laurenzo continue; 48*013b9322SStella Laurenzo } else { 49*013b9322SStella Laurenzo throw; 50*013b9322SStella Laurenzo } 51*013b9322SStella Laurenzo } 52*013b9322SStella Laurenzo break; 53*013b9322SStella Laurenzo } 54*013b9322SStella Laurenzo 55*013b9322SStella Laurenzo // Note: Iterator cannot be shared from prior to loading, since re-entrancy 56*013b9322SStella Laurenzo // may have occurred, which may do anything. 57*013b9322SStella Laurenzo loadedDialectModules.insert(dialectNamespace); 58*013b9322SStella Laurenzo } 59*013b9322SStella Laurenzo 60*013b9322SStella Laurenzo void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 61*013b9322SStella Laurenzo py::object pyClass) { 62*013b9322SStella Laurenzo py::object &found = dialectClassMap[dialectNamespace]; 63*013b9322SStella Laurenzo if (found) { 64*013b9322SStella Laurenzo throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 65*013b9322SStella Laurenzo dialectNamespace + 66*013b9322SStella Laurenzo "' is already registered."); 67*013b9322SStella Laurenzo } 68*013b9322SStella Laurenzo found = std::move(pyClass); 69*013b9322SStella Laurenzo } 70*013b9322SStella Laurenzo 71*013b9322SStella Laurenzo void PyGlobals::registerOperationImpl(const std::string &operationName, 72*013b9322SStella Laurenzo py::object pyClass, py::object rawClass) { 73*013b9322SStella Laurenzo py::object &found = operationClassMap[operationName]; 74*013b9322SStella Laurenzo if (found) { 75*013b9322SStella Laurenzo throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 76*013b9322SStella Laurenzo operationName + 77*013b9322SStella Laurenzo "' is already registered."); 78*013b9322SStella Laurenzo } 79*013b9322SStella Laurenzo found = std::move(pyClass); 80*013b9322SStella Laurenzo rawOperationClassMap[operationName] = std::move(rawClass); 81*013b9322SStella Laurenzo } 82*013b9322SStella Laurenzo 83*013b9322SStella Laurenzo llvm::Optional<py::object> 84*013b9322SStella Laurenzo PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 85*013b9322SStella Laurenzo loadDialectModule(dialectNamespace); 86*013b9322SStella Laurenzo // Fast match against the class map first (common case). 87*013b9322SStella Laurenzo const auto foundIt = dialectClassMap.find(dialectNamespace); 88*013b9322SStella Laurenzo if (foundIt != dialectClassMap.end()) { 89*013b9322SStella Laurenzo if (foundIt->second.is_none()) 90*013b9322SStella Laurenzo return llvm::None; 91*013b9322SStella Laurenzo assert(foundIt->second && "py::object is defined"); 92*013b9322SStella Laurenzo return foundIt->second; 93*013b9322SStella Laurenzo } 94*013b9322SStella Laurenzo 95*013b9322SStella Laurenzo // Not found and loading did not yield a registration. Negative cache. 96*013b9322SStella Laurenzo dialectClassMap[dialectNamespace] = py::none(); 97*013b9322SStella Laurenzo return llvm::None; 98*013b9322SStella Laurenzo } 99*013b9322SStella Laurenzo 100*013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 101*013b9322SStella Laurenzo // Module initialization. 102*013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 103*013b9322SStella Laurenzo 104722475a3SStella Laurenzo PYBIND11_MODULE(_mlir, m) { 105722475a3SStella Laurenzo m.doc() = "MLIR Python Native Extension"; 106722475a3SStella Laurenzo 107*013b9322SStella Laurenzo py::class_<PyGlobals>(m, "_Globals") 108*013b9322SStella Laurenzo .def_property("dialect_search_modules", 109*013b9322SStella Laurenzo &PyGlobals::getDialectSearchPrefixes, 110*013b9322SStella Laurenzo &PyGlobals::setDialectSearchPrefixes) 111*013b9322SStella Laurenzo .def("append_dialect_search_prefix", 112*013b9322SStella Laurenzo [](PyGlobals &self, std::string moduleName) { 113*013b9322SStella Laurenzo self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 114*013b9322SStella Laurenzo }) 115*013b9322SStella Laurenzo .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 116*013b9322SStella Laurenzo "Testing hook for directly registering a dialect") 117*013b9322SStella Laurenzo .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 118*013b9322SStella Laurenzo "Testing hook for directly registering an operation"); 119*013b9322SStella Laurenzo 120*013b9322SStella Laurenzo // Aside from making the globals accessible to python, having python manage 121*013b9322SStella Laurenzo // it is necessary to make sure it is destroyed (and releases its python 122*013b9322SStella Laurenzo // resources) properly. 123*013b9322SStella Laurenzo m.attr("globals") = 124*013b9322SStella Laurenzo py::cast(new PyGlobals, py::return_value_policy::take_ownership); 125*013b9322SStella Laurenzo 126*013b9322SStella Laurenzo // Registration decorators. 127*013b9322SStella Laurenzo m.def( 128*013b9322SStella Laurenzo "register_dialect", 129*013b9322SStella Laurenzo [](py::object pyClass) { 130*013b9322SStella Laurenzo std::string dialectNamespace = 131*013b9322SStella Laurenzo pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 132*013b9322SStella Laurenzo PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 133*013b9322SStella Laurenzo return pyClass; 134*013b9322SStella Laurenzo }, 135*013b9322SStella Laurenzo "Class decorator for registering a custom Dialect wrapper"); 136*013b9322SStella Laurenzo m.def( 137*013b9322SStella Laurenzo "register_operation", 138*013b9322SStella Laurenzo [](py::object dialectClass) -> py::cpp_function { 139*013b9322SStella Laurenzo return py::cpp_function( 140*013b9322SStella Laurenzo [dialectClass](py::object opClass) -> py::object { 141*013b9322SStella Laurenzo std::string operationName = 142*013b9322SStella Laurenzo opClass.attr("OPERATION_NAME").cast<std::string>(); 143*013b9322SStella Laurenzo auto rawSubclass = PyOpView::createRawSubclass(opClass); 144*013b9322SStella Laurenzo PyGlobals::get().registerOperationImpl(operationName, opClass, 145*013b9322SStella Laurenzo rawSubclass); 146*013b9322SStella Laurenzo 147*013b9322SStella Laurenzo // Dict-stuff the new opClass by name onto the dialect class. 148*013b9322SStella Laurenzo py::object opClassName = opClass.attr("__name__"); 149*013b9322SStella Laurenzo dialectClass.attr(opClassName) = opClass; 150*013b9322SStella Laurenzo 151*013b9322SStella Laurenzo // Now create a special "Raw" subclass that passes through 152*013b9322SStella Laurenzo // construction to the OpView parent (bypasses the intermediate 153*013b9322SStella Laurenzo // child's __init__). 154*013b9322SStella Laurenzo opClass.attr("_Raw") = rawSubclass; 155*013b9322SStella Laurenzo return opClass; 156*013b9322SStella Laurenzo }); 157*013b9322SStella Laurenzo }, 158*013b9322SStella Laurenzo "Class decorator for registering a custom Operation wrapper"); 159*013b9322SStella Laurenzo 160fcd2969dSzhanghb97 // Define and populate IR submodule. 161fcd2969dSzhanghb97 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 162fcd2969dSzhanghb97 populateIRSubmodule(irModule); 163722475a3SStella Laurenzo } 164