1722475a3SStella Laurenzo //===- MainModule.cpp - Main pybind module --------------------------------===// 2722475a3SStella Laurenzo // 3722475a3SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4722475a3SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5722475a3SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6722475a3SStella Laurenzo // 7722475a3SStella Laurenzo //===----------------------------------------------------------------------===// 8722475a3SStella Laurenzo 9722475a3SStella Laurenzo #include <tuple> 10722475a3SStella Laurenzo 11013b9322SStella Laurenzo #include "PybindUtils.h" 12722475a3SStella Laurenzo 13013b9322SStella Laurenzo #include "Globals.h" 14fcd2969dSzhanghb97 #include "IRModules.h" 15722475a3SStella Laurenzo 1695b77f2eSStella Laurenzo namespace py = pybind11; 17722475a3SStella Laurenzo using namespace mlir; 1895b77f2eSStella Laurenzo using namespace mlir::python; 19722475a3SStella Laurenzo 20013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 21013b9322SStella Laurenzo // PyGlobals 22013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 23013b9322SStella Laurenzo 24013b9322SStella Laurenzo PyGlobals *PyGlobals::instance = nullptr; 25013b9322SStella Laurenzo 26013b9322SStella Laurenzo PyGlobals::PyGlobals() { 27013b9322SStella Laurenzo assert(!instance && "PyGlobals already constructed"); 28013b9322SStella Laurenzo instance = this; 29013b9322SStella Laurenzo } 30013b9322SStella Laurenzo 31013b9322SStella Laurenzo PyGlobals::~PyGlobals() { instance = nullptr; } 32013b9322SStella Laurenzo 33*8260db75SStella Laurenzo void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 34*8260db75SStella Laurenzo py::gil_scoped_acquire(); 35*8260db75SStella Laurenzo if (loadedDialectModulesCache.contains(dialectNamespace)) 36013b9322SStella Laurenzo return; 37013b9322SStella Laurenzo // Since re-entrancy is possible, make a copy of the search prefixes. 38013b9322SStella Laurenzo std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 39013b9322SStella Laurenzo py::object loaded; 40013b9322SStella Laurenzo for (std::string moduleName : localSearchPrefixes) { 41013b9322SStella Laurenzo moduleName.push_back('.'); 42*8260db75SStella Laurenzo moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 43013b9322SStella Laurenzo 44013b9322SStella Laurenzo try { 45*8260db75SStella Laurenzo py::gil_scoped_release(); 46013b9322SStella Laurenzo loaded = py::module::import(moduleName.c_str()); 47013b9322SStella Laurenzo } catch (py::error_already_set &e) { 48013b9322SStella Laurenzo if (e.matches(PyExc_ModuleNotFoundError)) { 49013b9322SStella Laurenzo continue; 50013b9322SStella Laurenzo } else { 51013b9322SStella Laurenzo throw; 52013b9322SStella Laurenzo } 53013b9322SStella Laurenzo } 54013b9322SStella Laurenzo break; 55013b9322SStella Laurenzo } 56013b9322SStella Laurenzo 57013b9322SStella Laurenzo // Note: Iterator cannot be shared from prior to loading, since re-entrancy 58013b9322SStella Laurenzo // may have occurred, which may do anything. 59*8260db75SStella Laurenzo loadedDialectModulesCache.insert(dialectNamespace); 60013b9322SStella Laurenzo } 61013b9322SStella Laurenzo 62013b9322SStella Laurenzo void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 63013b9322SStella Laurenzo py::object pyClass) { 64*8260db75SStella Laurenzo py::gil_scoped_acquire(); 65013b9322SStella Laurenzo py::object &found = dialectClassMap[dialectNamespace]; 66013b9322SStella Laurenzo if (found) { 67013b9322SStella Laurenzo throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 68013b9322SStella Laurenzo dialectNamespace + 69013b9322SStella Laurenzo "' is already registered."); 70013b9322SStella Laurenzo } 71013b9322SStella Laurenzo found = std::move(pyClass); 72013b9322SStella Laurenzo } 73013b9322SStella Laurenzo 74013b9322SStella Laurenzo void PyGlobals::registerOperationImpl(const std::string &operationName, 75*8260db75SStella Laurenzo py::object pyClass, 76*8260db75SStella Laurenzo py::object rawOpViewClass) { 77*8260db75SStella Laurenzo py::gil_scoped_acquire(); 78013b9322SStella Laurenzo py::object &found = operationClassMap[operationName]; 79013b9322SStella Laurenzo if (found) { 80013b9322SStella Laurenzo throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 81013b9322SStella Laurenzo operationName + 82013b9322SStella Laurenzo "' is already registered."); 83013b9322SStella Laurenzo } 84013b9322SStella Laurenzo found = std::move(pyClass); 85*8260db75SStella Laurenzo rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 86013b9322SStella Laurenzo } 87013b9322SStella Laurenzo 88013b9322SStella Laurenzo llvm::Optional<py::object> 89013b9322SStella Laurenzo PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 90*8260db75SStella Laurenzo py::gil_scoped_acquire(); 91013b9322SStella Laurenzo loadDialectModule(dialectNamespace); 92013b9322SStella Laurenzo // Fast match against the class map first (common case). 93013b9322SStella Laurenzo const auto foundIt = dialectClassMap.find(dialectNamespace); 94013b9322SStella Laurenzo if (foundIt != dialectClassMap.end()) { 95013b9322SStella Laurenzo if (foundIt->second.is_none()) 96013b9322SStella Laurenzo return llvm::None; 97013b9322SStella Laurenzo assert(foundIt->second && "py::object is defined"); 98013b9322SStella Laurenzo return foundIt->second; 99013b9322SStella Laurenzo } 100013b9322SStella Laurenzo 101013b9322SStella Laurenzo // Not found and loading did not yield a registration. Negative cache. 102013b9322SStella Laurenzo dialectClassMap[dialectNamespace] = py::none(); 103013b9322SStella Laurenzo return llvm::None; 104013b9322SStella Laurenzo } 105013b9322SStella Laurenzo 106*8260db75SStella Laurenzo llvm::Optional<pybind11::object> 107*8260db75SStella Laurenzo PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 108*8260db75SStella Laurenzo { 109*8260db75SStella Laurenzo py::gil_scoped_acquire(); 110*8260db75SStella Laurenzo auto foundIt = rawOpViewClassMapCache.find(operationName); 111*8260db75SStella Laurenzo if (foundIt != rawOpViewClassMapCache.end()) { 112*8260db75SStella Laurenzo if (foundIt->second.is_none()) 113*8260db75SStella Laurenzo return llvm::None; 114*8260db75SStella Laurenzo assert(foundIt->second && "py::object is defined"); 115*8260db75SStella Laurenzo return foundIt->second; 116*8260db75SStella Laurenzo } 117*8260db75SStella Laurenzo } 118*8260db75SStella Laurenzo 119*8260db75SStella Laurenzo // Not found. Load the dialect namespace. 120*8260db75SStella Laurenzo auto split = operationName.split('.'); 121*8260db75SStella Laurenzo llvm::StringRef dialectNamespace = split.first; 122*8260db75SStella Laurenzo loadDialectModule(dialectNamespace); 123*8260db75SStella Laurenzo 124*8260db75SStella Laurenzo // Attempt to find from the canonical map and cache. 125*8260db75SStella Laurenzo { 126*8260db75SStella Laurenzo py::gil_scoped_acquire(); 127*8260db75SStella Laurenzo auto foundIt = rawOpViewClassMap.find(operationName); 128*8260db75SStella Laurenzo if (foundIt != rawOpViewClassMap.end()) { 129*8260db75SStella Laurenzo if (foundIt->second.is_none()) 130*8260db75SStella Laurenzo return llvm::None; 131*8260db75SStella Laurenzo assert(foundIt->second && "py::object is defined"); 132*8260db75SStella Laurenzo // Positive cache. 133*8260db75SStella Laurenzo rawOpViewClassMapCache[operationName] = foundIt->second; 134*8260db75SStella Laurenzo return foundIt->second; 135*8260db75SStella Laurenzo } else { 136*8260db75SStella Laurenzo // Negative cache. 137*8260db75SStella Laurenzo rawOpViewClassMap[operationName] = py::none(); 138*8260db75SStella Laurenzo return llvm::None; 139*8260db75SStella Laurenzo } 140*8260db75SStella Laurenzo } 141*8260db75SStella Laurenzo } 142*8260db75SStella Laurenzo 143*8260db75SStella Laurenzo void PyGlobals::clearImportCache() { 144*8260db75SStella Laurenzo py::gil_scoped_acquire(); 145*8260db75SStella Laurenzo loadedDialectModulesCache.clear(); 146*8260db75SStella Laurenzo rawOpViewClassMapCache.clear(); 147*8260db75SStella Laurenzo } 148*8260db75SStella Laurenzo 149013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 150013b9322SStella Laurenzo // Module initialization. 151013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 152013b9322SStella Laurenzo 153722475a3SStella Laurenzo PYBIND11_MODULE(_mlir, m) { 154722475a3SStella Laurenzo m.doc() = "MLIR Python Native Extension"; 155722475a3SStella Laurenzo 156013b9322SStella Laurenzo py::class_<PyGlobals>(m, "_Globals") 157013b9322SStella Laurenzo .def_property("dialect_search_modules", 158013b9322SStella Laurenzo &PyGlobals::getDialectSearchPrefixes, 159013b9322SStella Laurenzo &PyGlobals::setDialectSearchPrefixes) 160013b9322SStella Laurenzo .def("append_dialect_search_prefix", 161013b9322SStella Laurenzo [](PyGlobals &self, std::string moduleName) { 162013b9322SStella Laurenzo self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 163*8260db75SStella Laurenzo self.clearImportCache(); 164013b9322SStella Laurenzo }) 165013b9322SStella Laurenzo .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 166013b9322SStella Laurenzo "Testing hook for directly registering a dialect") 167013b9322SStella Laurenzo .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 168013b9322SStella Laurenzo "Testing hook for directly registering an operation"); 169013b9322SStella Laurenzo 170013b9322SStella Laurenzo // Aside from making the globals accessible to python, having python manage 171013b9322SStella Laurenzo // it is necessary to make sure it is destroyed (and releases its python 172013b9322SStella Laurenzo // resources) properly. 173013b9322SStella Laurenzo m.attr("globals") = 174013b9322SStella Laurenzo py::cast(new PyGlobals, py::return_value_policy::take_ownership); 175013b9322SStella Laurenzo 176013b9322SStella Laurenzo // Registration decorators. 177013b9322SStella Laurenzo m.def( 178013b9322SStella Laurenzo "register_dialect", 179013b9322SStella Laurenzo [](py::object pyClass) { 180013b9322SStella Laurenzo std::string dialectNamespace = 181013b9322SStella Laurenzo pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 182013b9322SStella Laurenzo PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 183013b9322SStella Laurenzo return pyClass; 184013b9322SStella Laurenzo }, 185013b9322SStella Laurenzo "Class decorator for registering a custom Dialect wrapper"); 186013b9322SStella Laurenzo m.def( 187013b9322SStella Laurenzo "register_operation", 188013b9322SStella Laurenzo [](py::object dialectClass) -> py::cpp_function { 189013b9322SStella Laurenzo return py::cpp_function( 190013b9322SStella Laurenzo [dialectClass](py::object opClass) -> py::object { 191013b9322SStella Laurenzo std::string operationName = 192013b9322SStella Laurenzo opClass.attr("OPERATION_NAME").cast<std::string>(); 193013b9322SStella Laurenzo auto rawSubclass = PyOpView::createRawSubclass(opClass); 194013b9322SStella Laurenzo PyGlobals::get().registerOperationImpl(operationName, opClass, 195013b9322SStella Laurenzo rawSubclass); 196013b9322SStella Laurenzo 197013b9322SStella Laurenzo // Dict-stuff the new opClass by name onto the dialect class. 198013b9322SStella Laurenzo py::object opClassName = opClass.attr("__name__"); 199013b9322SStella Laurenzo dialectClass.attr(opClassName) = opClass; 200013b9322SStella Laurenzo 201013b9322SStella Laurenzo // Now create a special "Raw" subclass that passes through 202013b9322SStella Laurenzo // construction to the OpView parent (bypasses the intermediate 203013b9322SStella Laurenzo // child's __init__). 204013b9322SStella Laurenzo opClass.attr("_Raw") = rawSubclass; 205013b9322SStella Laurenzo return opClass; 206013b9322SStella Laurenzo }); 207013b9322SStella Laurenzo }, 208013b9322SStella Laurenzo "Class decorator for registering a custom Operation wrapper"); 209013b9322SStella Laurenzo 210fcd2969dSzhanghb97 // Define and populate IR submodule. 211fcd2969dSzhanghb97 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 212fcd2969dSzhanghb97 populateIRSubmodule(irModule); 213722475a3SStella Laurenzo } 214