1 //===- IRModule.cpp - IR pybind module ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 #include "Globals.h" 11 #include "PybindUtils.h" 12 13 #include <vector> 14 15 namespace py = pybind11; 16 using namespace mlir; 17 using namespace mlir::python; 18 19 // ----------------------------------------------------------------------------- 20 // PyGlobals 21 // ----------------------------------------------------------------------------- 22 23 PyGlobals *PyGlobals::instance = nullptr; 24 25 PyGlobals::PyGlobals() { 26 assert(!instance && "PyGlobals already constructed"); 27 instance = this; 28 } 29 30 PyGlobals::~PyGlobals() { instance = nullptr; } 31 32 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 33 py::gil_scoped_acquire(); 34 if (loadedDialectModulesCache.contains(dialectNamespace)) 35 return; 36 // Since re-entrancy is possible, make a copy of the search prefixes. 37 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 38 py::object loaded; 39 for (std::string moduleName : localSearchPrefixes) { 40 moduleName.push_back('.'); 41 moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 42 43 try { 44 py::gil_scoped_release(); 45 loaded = py::module::import(moduleName.c_str()); 46 } catch (py::error_already_set &e) { 47 if (e.matches(PyExc_ModuleNotFoundError)) { 48 continue; 49 } else { 50 throw; 51 } 52 } 53 break; 54 } 55 56 // Note: Iterator cannot be shared from prior to loading, since re-entrancy 57 // may have occurred, which may do anything. 58 loadedDialectModulesCache.insert(dialectNamespace); 59 } 60 61 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 62 py::object pyClass) { 63 py::gil_scoped_acquire(); 64 py::object &found = dialectClassMap[dialectNamespace]; 65 if (found) { 66 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 67 dialectNamespace + 68 "' is already registered."); 69 } 70 found = std::move(pyClass); 71 } 72 73 void PyGlobals::registerOperationImpl(const std::string &operationName, 74 py::object pyClass, 75 py::object rawOpViewClass) { 76 py::gil_scoped_acquire(); 77 py::object &found = operationClassMap[operationName]; 78 if (found) { 79 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 80 operationName + 81 "' is already registered."); 82 } 83 found = std::move(pyClass); 84 rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 85 } 86 87 llvm::Optional<py::object> 88 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 89 py::gil_scoped_acquire(); 90 loadDialectModule(dialectNamespace); 91 // Fast match against the class map first (common case). 92 const auto foundIt = dialectClassMap.find(dialectNamespace); 93 if (foundIt != dialectClassMap.end()) { 94 if (foundIt->second.is_none()) 95 return llvm::None; 96 assert(foundIt->second && "py::object is defined"); 97 return foundIt->second; 98 } 99 100 // Not found and loading did not yield a registration. Negative cache. 101 dialectClassMap[dialectNamespace] = py::none(); 102 return llvm::None; 103 } 104 105 llvm::Optional<pybind11::object> 106 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 107 { 108 py::gil_scoped_acquire(); 109 auto foundIt = rawOpViewClassMapCache.find(operationName); 110 if (foundIt != rawOpViewClassMapCache.end()) { 111 if (foundIt->second.is_none()) 112 return llvm::None; 113 assert(foundIt->second && "py::object is defined"); 114 return foundIt->second; 115 } 116 } 117 118 // Not found. Load the dialect namespace. 119 auto split = operationName.split('.'); 120 llvm::StringRef dialectNamespace = split.first; 121 loadDialectModule(dialectNamespace); 122 123 // Attempt to find from the canonical map and cache. 124 { 125 py::gil_scoped_acquire(); 126 auto foundIt = rawOpViewClassMap.find(operationName); 127 if (foundIt != rawOpViewClassMap.end()) { 128 if (foundIt->second.is_none()) 129 return llvm::None; 130 assert(foundIt->second && "py::object is defined"); 131 // Positive cache. 132 rawOpViewClassMapCache[operationName] = foundIt->second; 133 return foundIt->second; 134 } else { 135 // Negative cache. 136 rawOpViewClassMap[operationName] = py::none(); 137 return llvm::None; 138 } 139 } 140 } 141 142 void PyGlobals::clearImportCache() { 143 py::gil_scoped_acquire(); 144 loadedDialectModulesCache.clear(); 145 rawOpViewClassMapCache.clear(); 146 } 147