1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "Globals.h" 14 #include "IRModules.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 // ----------------------------------------------------------------------------- 21 // PyGlobals 22 // ----------------------------------------------------------------------------- 23 24 PyGlobals *PyGlobals::instance = nullptr; 25 26 PyGlobals::PyGlobals() { 27 assert(!instance && "PyGlobals already constructed"); 28 instance = this; 29 } 30 31 PyGlobals::~PyGlobals() { instance = nullptr; } 32 33 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 34 py::gil_scoped_acquire(); 35 if (loadedDialectModulesCache.contains(dialectNamespace)) 36 return; 37 // Since re-entrancy is possible, make a copy of the search prefixes. 38 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 39 py::object loaded; 40 for (std::string moduleName : localSearchPrefixes) { 41 moduleName.push_back('.'); 42 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 43 44 try { 45 py::gil_scoped_release(); 46 loaded = py::module::import(moduleName.c_str()); 47 } catch (py::error_already_set &e) { 48 if (e.matches(PyExc_ModuleNotFoundError)) { 49 continue; 50 } else { 51 throw; 52 } 53 } 54 break; 55 } 56 57 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 58 // may have occurred, which may do anything. 59 loadedDialectModulesCache.insert(dialectNamespace); 60 } 61 62 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 63 py::object pyClass) { 64 py::gil_scoped_acquire(); 65 py::object &found = dialectClassMap[dialectNamespace]; 66 if (found) { 67 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 68 dialectNamespace + 69 "' is already registered."); 70 } 71 found = std::move(pyClass); 72 } 73 74 void PyGlobals::registerOperationImpl(const std::string &operationName, 75 py::object pyClass, 76 py::object rawOpViewClass) { 77 py::gil_scoped_acquire(); 78 py::object &found = operationClassMap[operationName]; 79 if (found) { 80 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 81 operationName + 82 "' is already registered."); 83 } 84 found = std::move(pyClass); 85 rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 86 } 87 88 llvm::Optional<py::object> 89 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 90 py::gil_scoped_acquire(); 91 loadDialectModule(dialectNamespace); 92 // Fast match against the class map first (common case). 93 const auto foundIt = dialectClassMap.find(dialectNamespace); 94 if (foundIt != dialectClassMap.end()) { 95 if (foundIt->second.is_none()) 96 return llvm::None; 97 assert(foundIt->second && "py::object is defined"); 98 return foundIt->second; 99 } 100 101 // Not found and loading did not yield a registration. Negative cache. 102 dialectClassMap[dialectNamespace] = py::none(); 103 return llvm::None; 104 } 105 106 llvm::Optional<pybind11::object> 107 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 108 { 109 py::gil_scoped_acquire(); 110 auto foundIt = rawOpViewClassMapCache.find(operationName); 111 if (foundIt != rawOpViewClassMapCache.end()) { 112 if (foundIt->second.is_none()) 113 return llvm::None; 114 assert(foundIt->second && "py::object is defined"); 115 return foundIt->second; 116 } 117 } 118 119 // Not found. Load the dialect namespace. 120 auto split = operationName.split('.'); 121 llvm::StringRef dialectNamespace = split.first; 122 loadDialectModule(dialectNamespace); 123 124 // Attempt to find from the canonical map and cache. 125 { 126 py::gil_scoped_acquire(); 127 auto foundIt = rawOpViewClassMap.find(operationName); 128 if (foundIt != rawOpViewClassMap.end()) { 129 if (foundIt->second.is_none()) 130 return llvm::None; 131 assert(foundIt->second && "py::object is defined"); 132 // Positive cache. 133 rawOpViewClassMapCache[operationName] = foundIt->second; 134 return foundIt->second; 135 } else { 136 // Negative cache. 137 rawOpViewClassMap[operationName] = py::none(); 138 return llvm::None; 139 } 140 } 141 } 142 143 void PyGlobals::clearImportCache() { 144 py::gil_scoped_acquire(); 145 loadedDialectModulesCache.clear(); 146 rawOpViewClassMapCache.clear(); 147 } 148 149 // ----------------------------------------------------------------------------- 150 // Module initialization. 151 // ----------------------------------------------------------------------------- 152 153 PYBIND11_MODULE(_mlir, m) { 154 m.doc() = "MLIR Python Native Extension"; 155 156 py::class_<PyGlobals>(m, "_Globals") 157 .def_property("dialect_search_modules", 158 &PyGlobals::getDialectSearchPrefixes, 159 &PyGlobals::setDialectSearchPrefixes) 160 .def("append_dialect_search_prefix", 161 [](PyGlobals &self, std::string moduleName) { 162 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 163 self.clearImportCache(); 164 }) 165 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 166 "Testing hook for directly registering a dialect") 167 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 168 "Testing hook for directly registering an operation"); 169 170 // Aside from making the globals accessible to python, having python manage 171 // it is necessary to make sure it is destroyed (and releases its python 172 // resources) properly. 173 m.attr("globals") = 174 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 175 176 // Registration decorators. 177 m.def( 178 "register_dialect", 179 [](py::object pyClass) { 180 std::string dialectNamespace = 181 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 182 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 183 return pyClass; 184 }, 185 "Class decorator for registering a custom Dialect wrapper"); 186 m.def( 187 "register_operation", 188 [](py::object dialectClass) -> py::cpp_function { 189 return py::cpp_function( 190 [dialectClass](py::object opClass) -> py::object { 191 std::string operationName = 192 opClass.attr("OPERATION_NAME").cast<std::string>(); 193 auto rawSubclass = PyOpView::createRawSubclass(opClass); 194 PyGlobals::get().registerOperationImpl(operationName, opClass, 195 rawSubclass); 196 197 // Dict-stuff the new opClass by name onto the dialect class. 198 py::object opClassName = opClass.attr("__name__"); 199 dialectClass.attr(opClassName) = opClass; 200 201 // Now create a special "Raw" subclass that passes through 202 // construction to the OpView parent (bypasses the intermediate 203 // child's __init__). 204 opClass.attr("_Raw") = rawSubclass; 205 return opClass; 206 }); 207 }, 208 "Class decorator for registering a custom Operation wrapper"); 209 210 // Define and populate IR submodule. 211 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 212 populateIRSubmodule(irModule); 213 } 214