1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 #include "Globals.h" 11 #include "PybindUtils.h" 12 13 #include <vector> 14 15 #include "mlir-c/Bindings/Python/Interop.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 // ----------------------------------------------------------------------------- 22 // PyGlobals 23 // ----------------------------------------------------------------------------- 24 25 PyGlobals *PyGlobals::instance = nullptr; 26 27 PyGlobals::PyGlobals() { 28 assert(!instance && "PyGlobals already constructed"); 29 instance = this; 30 // The default search path include {mlir.}dialects, where {mlir.} is the 31 // package prefix configured at compile time. 32 dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 33 } 34 35 PyGlobals::~PyGlobals() { instance = nullptr; } 36 37 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 38 py::gil_scoped_acquire(); 39 if (loadedDialectModulesCache.contains(dialectNamespace)) 40 return; 41 // Since re-entrancy is possible, make a copy of the search prefixes. 42 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 43 py::object loaded; 44 for (std::string moduleName : localSearchPrefixes) { 45 moduleName.push_back('.'); 46 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 47 48 try { 49 py::gil_scoped_release(); 50 loaded = py::module::import(moduleName.c_str()); 51 } catch (py::error_already_set &e) { 52 if (e.matches(PyExc_ModuleNotFoundError)) { 53 continue; 54 } else { 55 throw; 56 } 57 } 58 break; 59 } 60 61 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 62 // may have occurred, which may do anything. 63 loadedDialectModulesCache.insert(dialectNamespace); 64 } 65 66 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 67 py::object pyClass) { 68 py::gil_scoped_acquire(); 69 py::object &found = dialectClassMap[dialectNamespace]; 70 if (found) { 71 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 72 dialectNamespace + 73 "' is already registered."); 74 } 75 found = std::move(pyClass); 76 } 77 78 void PyGlobals::registerOperationImpl(const std::string &operationName, 79 py::object pyClass, 80 py::object rawOpViewClass) { 81 py::gil_scoped_acquire(); 82 py::object &found = operationClassMap[operationName]; 83 if (found) { 84 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 85 operationName + 86 "' is already registered."); 87 } 88 found = std::move(pyClass); 89 rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 90 } 91 92 llvm::Optional<py::object> 93 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 94 py::gil_scoped_acquire(); 95 loadDialectModule(dialectNamespace); 96 // Fast match against the class map first (common case). 97 const auto foundIt = dialectClassMap.find(dialectNamespace); 98 if (foundIt != dialectClassMap.end()) { 99 if (foundIt->second.is_none()) 100 return llvm::None; 101 assert(foundIt->second && "py::object is defined"); 102 return foundIt->second; 103 } 104 105 // Not found and loading did not yield a registration. Negative cache. 106 dialectClassMap[dialectNamespace] = py::none(); 107 return llvm::None; 108 } 109 110 llvm::Optional<pybind11::object> 111 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 112 { 113 py::gil_scoped_acquire(); 114 auto foundIt = rawOpViewClassMapCache.find(operationName); 115 if (foundIt != rawOpViewClassMapCache.end()) { 116 if (foundIt->second.is_none()) 117 return llvm::None; 118 assert(foundIt->second && "py::object is defined"); 119 return foundIt->second; 120 } 121 } 122 123 // Not found. Load the dialect namespace. 124 auto split = operationName.split('.'); 125 llvm::StringRef dialectNamespace = split.first; 126 loadDialectModule(dialectNamespace); 127 128 // Attempt to find from the canonical map and cache. 129 { 130 py::gil_scoped_acquire(); 131 auto foundIt = rawOpViewClassMap.find(operationName); 132 if (foundIt != rawOpViewClassMap.end()) { 133 if (foundIt->second.is_none()) 134 return llvm::None; 135 assert(foundIt->second && "py::object is defined"); 136 // Positive cache. 137 rawOpViewClassMapCache[operationName] = foundIt->second; 138 return foundIt->second; 139 } else { 140 // Negative cache. 141 rawOpViewClassMap[operationName] = py::none(); 142 return llvm::None; 143 } 144 } 145 } 146 147 void PyGlobals::clearImportCache() { 148 py::gil_scoped_acquire(); 149 loadedDialectModulesCache.clear(); 150 rawOpViewClassMapCache.clear(); 151 } 152