1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "Globals.h" 14 #include "IRModules.h" 15 #include "Pass.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 // ----------------------------------------------------------------------------- 22 // PyGlobals 23 // ----------------------------------------------------------------------------- 24 25 PyGlobals *PyGlobals::instance = nullptr; 26 27 PyGlobals::PyGlobals() { 28 assert(!instance && "PyGlobals already constructed"); 29 instance = this; 30 } 31 32 PyGlobals::~PyGlobals() { instance = nullptr; } 33 34 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 35 py::gil_scoped_acquire(); 36 if (loadedDialectModulesCache.contains(dialectNamespace)) 37 return; 38 // Since re-entrancy is possible, make a copy of the search prefixes. 39 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 40 py::object loaded; 41 for (std::string moduleName : localSearchPrefixes) { 42 moduleName.push_back('.'); 43 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 44 45 try { 46 py::gil_scoped_release(); 47 loaded = py::module::import(moduleName.c_str()); 48 } catch (py::error_already_set &e) { 49 if (e.matches(PyExc_ModuleNotFoundError)) { 50 continue; 51 } else { 52 throw; 53 } 54 } 55 break; 56 } 57 58 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 59 // may have occurred, which may do anything. 60 loadedDialectModulesCache.insert(dialectNamespace); 61 } 62 63 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 64 py::object pyClass) { 65 py::gil_scoped_acquire(); 66 py::object &found = dialectClassMap[dialectNamespace]; 67 if (found) { 68 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 69 dialectNamespace + 70 "' is already registered."); 71 } 72 found = std::move(pyClass); 73 } 74 75 void PyGlobals::registerOperationImpl(const std::string &operationName, 76 py::object pyClass, 77 py::object rawOpViewClass) { 78 py::gil_scoped_acquire(); 79 py::object &found = operationClassMap[operationName]; 80 if (found) { 81 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 82 operationName + 83 "' is already registered."); 84 } 85 found = std::move(pyClass); 86 rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 87 } 88 89 llvm::Optional<py::object> 90 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 91 py::gil_scoped_acquire(); 92 loadDialectModule(dialectNamespace); 93 // Fast match against the class map first (common case). 94 const auto foundIt = dialectClassMap.find(dialectNamespace); 95 if (foundIt != dialectClassMap.end()) { 96 if (foundIt->second.is_none()) 97 return llvm::None; 98 assert(foundIt->second && "py::object is defined"); 99 return foundIt->second; 100 } 101 102 // Not found and loading did not yield a registration. Negative cache. 103 dialectClassMap[dialectNamespace] = py::none(); 104 return llvm::None; 105 } 106 107 llvm::Optional<pybind11::object> 108 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 109 { 110 py::gil_scoped_acquire(); 111 auto foundIt = rawOpViewClassMapCache.find(operationName); 112 if (foundIt != rawOpViewClassMapCache.end()) { 113 if (foundIt->second.is_none()) 114 return llvm::None; 115 assert(foundIt->second && "py::object is defined"); 116 return foundIt->second; 117 } 118 } 119 120 // Not found. Load the dialect namespace. 121 auto split = operationName.split('.'); 122 llvm::StringRef dialectNamespace = split.first; 123 loadDialectModule(dialectNamespace); 124 125 // Attempt to find from the canonical map and cache. 126 { 127 py::gil_scoped_acquire(); 128 auto foundIt = rawOpViewClassMap.find(operationName); 129 if (foundIt != rawOpViewClassMap.end()) { 130 if (foundIt->second.is_none()) 131 return llvm::None; 132 assert(foundIt->second && "py::object is defined"); 133 // Positive cache. 134 rawOpViewClassMapCache[operationName] = foundIt->second; 135 return foundIt->second; 136 } else { 137 // Negative cache. 138 rawOpViewClassMap[operationName] = py::none(); 139 return llvm::None; 140 } 141 } 142 } 143 144 void PyGlobals::clearImportCache() { 145 py::gil_scoped_acquire(); 146 loadedDialectModulesCache.clear(); 147 rawOpViewClassMapCache.clear(); 148 } 149 150 // ----------------------------------------------------------------------------- 151 // Module initialization. 152 // ----------------------------------------------------------------------------- 153 154 PYBIND11_MODULE(_mlir, m) { 155 m.doc() = "MLIR Python Native Extension"; 156 157 py::class_<PyGlobals>(m, "_Globals") 158 .def_property("dialect_search_modules", 159 &PyGlobals::getDialectSearchPrefixes, 160 &PyGlobals::setDialectSearchPrefixes) 161 .def("append_dialect_search_prefix", 162 [](PyGlobals &self, std::string moduleName) { 163 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 164 self.clearImportCache(); 165 }) 166 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 167 "Testing hook for directly registering a dialect") 168 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 169 "Testing hook for directly registering an operation"); 170 171 // Aside from making the globals accessible to python, having python manage 172 // it is necessary to make sure it is destroyed (and releases its python 173 // resources) properly. 174 m.attr("globals") = 175 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 176 177 // Registration decorators. 178 m.def( 179 "register_dialect", 180 [](py::object pyClass) { 181 std::string dialectNamespace = 182 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 183 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 184 return pyClass; 185 }, 186 "Class decorator for registering a custom Dialect wrapper"); 187 m.def( 188 "register_operation", 189 [](py::object dialectClass) -> py::cpp_function { 190 return py::cpp_function( 191 [dialectClass](py::object opClass) -> py::object { 192 std::string operationName = 193 opClass.attr("OPERATION_NAME").cast<std::string>(); 194 auto rawSubclass = PyOpView::createRawSubclass(opClass); 195 PyGlobals::get().registerOperationImpl(operationName, opClass, 196 rawSubclass); 197 198 // Dict-stuff the new opClass by name onto the dialect class. 199 py::object opClassName = opClass.attr("__name__"); 200 dialectClass.attr(opClassName) = opClass; 201 202 // Now create a special "Raw" subclass that passes through 203 // construction to the OpView parent (bypasses the intermediate 204 // child's __init__). 205 opClass.attr("_Raw") = rawSubclass; 206 return opClass; 207 }); 208 }, 209 "Class decorator for registering a custom Operation wrapper"); 210 211 // Define and populate IR submodule. 212 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 213 populateIRSubmodule(irModule); 214 215 // Define and populate PassManager submodule. 216 auto passModule = 217 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 218 populatePassManagerSubmodule(passModule); 219 } 220