1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "Globals.h" 14 #include "IRModules.h" 15 16 namespace py = pybind11; 17 using namespace mlir; 18 using namespace mlir::python; 19 20 // ----------------------------------------------------------------------------- 21 // PyGlobals 22 // ----------------------------------------------------------------------------- 23 24 PyGlobals *PyGlobals::instance = nullptr; 25 26 PyGlobals::PyGlobals() { 27 assert(!instance && "PyGlobals already constructed"); 28 instance = this; 29 } 30 31 PyGlobals::~PyGlobals() { instance = nullptr; } 32 33 void PyGlobals::loadDialectModule(const std::string &dialectNamespace) { 34 if (loadedDialectModules.contains(dialectNamespace)) 35 return; 36 // Since re-entrancy is possible, make a copy of the search prefixes. 37 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 38 py::object loaded; 39 for (std::string moduleName : localSearchPrefixes) { 40 moduleName.push_back('.'); 41 moduleName.append(dialectNamespace); 42 43 try { 44 loaded = py::module::import(moduleName.c_str()); 45 } catch (py::error_already_set &e) { 46 if (e.matches(PyExc_ModuleNotFoundError)) { 47 continue; 48 } else { 49 throw; 50 } 51 } 52 break; 53 } 54 55 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 56 // may have occurred, which may do anything. 57 loadedDialectModules.insert(dialectNamespace); 58 } 59 60 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 61 py::object pyClass) { 62 py::object &found = dialectClassMap[dialectNamespace]; 63 if (found) { 64 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 65 dialectNamespace + 66 "' is already registered."); 67 } 68 found = std::move(pyClass); 69 } 70 71 void PyGlobals::registerOperationImpl(const std::string &operationName, 72 py::object pyClass, py::object rawClass) { 73 py::object &found = operationClassMap[operationName]; 74 if (found) { 75 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 76 operationName + 77 "' is already registered."); 78 } 79 found = std::move(pyClass); 80 rawOperationClassMap[operationName] = std::move(rawClass); 81 } 82 83 llvm::Optional<py::object> 84 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 85 loadDialectModule(dialectNamespace); 86 // Fast match against the class map first (common case). 87 const auto foundIt = dialectClassMap.find(dialectNamespace); 88 if (foundIt != dialectClassMap.end()) { 89 if (foundIt->second.is_none()) 90 return llvm::None; 91 assert(foundIt->second && "py::object is defined"); 92 return foundIt->second; 93 } 94 95 // Not found and loading did not yield a registration. Negative cache. 96 dialectClassMap[dialectNamespace] = py::none(); 97 return llvm::None; 98 } 99 100 // ----------------------------------------------------------------------------- 101 // Module initialization. 102 // ----------------------------------------------------------------------------- 103 104 PYBIND11_MODULE(_mlir, m) { 105 m.doc() = "MLIR Python Native Extension"; 106 107 py::class_<PyGlobals>(m, "_Globals") 108 .def_property("dialect_search_modules", 109 &PyGlobals::getDialectSearchPrefixes, 110 &PyGlobals::setDialectSearchPrefixes) 111 .def("append_dialect_search_prefix", 112 [](PyGlobals &self, std::string moduleName) { 113 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 114 }) 115 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 116 "Testing hook for directly registering a dialect") 117 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 118 "Testing hook for directly registering an operation"); 119 120 // Aside from making the globals accessible to python, having python manage 121 // it is necessary to make sure it is destroyed (and releases its python 122 // resources) properly. 123 m.attr("globals") = 124 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 125 126 // Registration decorators. 127 m.def( 128 "register_dialect", 129 [](py::object pyClass) { 130 std::string dialectNamespace = 131 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 132 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 133 return pyClass; 134 }, 135 "Class decorator for registering a custom Dialect wrapper"); 136 m.def( 137 "register_operation", 138 [](py::object dialectClass) -> py::cpp_function { 139 return py::cpp_function( 140 [dialectClass](py::object opClass) -> py::object { 141 std::string operationName = 142 opClass.attr("OPERATION_NAME").cast<std::string>(); 143 auto rawSubclass = PyOpView::createRawSubclass(opClass); 144 PyGlobals::get().registerOperationImpl(operationName, opClass, 145 rawSubclass); 146 147 // Dict-stuff the new opClass by name onto the dialect class. 148 py::object opClassName = opClass.attr("__name__"); 149 dialectClass.attr(opClassName) = opClass; 150 151 // Now create a special "Raw" subclass that passes through 152 // construction to the OpView parent (bypasses the intermediate 153 // child's __init__). 154 opClass.attr("_Raw") = rawSubclass; 155 return opClass; 156 }); 157 }, 158 "Class decorator for registering a custom Operation wrapper"); 159 160 // Define and populate IR submodule. 161 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 162 populateIRSubmodule(irModule); 163 } 164