1ac0a70f3SAlex Zinenko //===- IRModule.cpp - IR pybind module ------------------------------------===// 2ac0a70f3SAlex Zinenko // 3ac0a70f3SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ac0a70f3SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5ac0a70f3SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ac0a70f3SAlex Zinenko // 7ac0a70f3SAlex Zinenko //===----------------------------------------------------------------------===// 8ac0a70f3SAlex Zinenko 9ac0a70f3SAlex Zinenko #include "IRModule.h" 10ac0a70f3SAlex Zinenko #include "Globals.h" 11ac0a70f3SAlex Zinenko #include "PybindUtils.h" 12ac0a70f3SAlex Zinenko 13ac0a70f3SAlex Zinenko #include <vector> 14ac0a70f3SAlex Zinenko 15*cb7b0381SStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h" 16*cb7b0381SStella Laurenzo 17ac0a70f3SAlex Zinenko namespace py = pybind11; 18ac0a70f3SAlex Zinenko using namespace mlir; 19ac0a70f3SAlex Zinenko using namespace mlir::python; 20ac0a70f3SAlex Zinenko 21ac0a70f3SAlex Zinenko // ----------------------------------------------------------------------------- 22ac0a70f3SAlex Zinenko // PyGlobals 23ac0a70f3SAlex Zinenko // ----------------------------------------------------------------------------- 24ac0a70f3SAlex Zinenko 25ac0a70f3SAlex Zinenko PyGlobals *PyGlobals::instance = nullptr; 26ac0a70f3SAlex Zinenko 27ac0a70f3SAlex Zinenko PyGlobals::PyGlobals() { 28ac0a70f3SAlex Zinenko assert(!instance && "PyGlobals already constructed"); 29ac0a70f3SAlex Zinenko instance = this; 30*cb7b0381SStella Laurenzo // The default search path include {mlir.}dialects, where {mlir.} is the 31*cb7b0381SStella Laurenzo // package prefix configured at compile time. 32*cb7b0381SStella Laurenzo dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); 33ac0a70f3SAlex Zinenko } 34ac0a70f3SAlex Zinenko 35ac0a70f3SAlex Zinenko PyGlobals::~PyGlobals() { instance = nullptr; } 36ac0a70f3SAlex Zinenko 37ac0a70f3SAlex Zinenko void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 38ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 39ac0a70f3SAlex Zinenko if (loadedDialectModulesCache.contains(dialectNamespace)) 40ac0a70f3SAlex Zinenko return; 41ac0a70f3SAlex Zinenko // Since re-entrancy is possible, make a copy of the search prefixes. 42ac0a70f3SAlex Zinenko std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 43ac0a70f3SAlex Zinenko py::object loaded; 44ac0a70f3SAlex Zinenko for (std::string moduleName : localSearchPrefixes) { 45ac0a70f3SAlex Zinenko moduleName.push_back('.'); 46ac0a70f3SAlex Zinenko moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 47ac0a70f3SAlex Zinenko 48ac0a70f3SAlex Zinenko try { 49ac0a70f3SAlex Zinenko py::gil_scoped_release(); 50ac0a70f3SAlex Zinenko loaded = py::module::import(moduleName.c_str()); 51ac0a70f3SAlex Zinenko } catch (py::error_already_set &e) { 52ac0a70f3SAlex Zinenko if (e.matches(PyExc_ModuleNotFoundError)) { 53ac0a70f3SAlex Zinenko continue; 54ac0a70f3SAlex Zinenko } else { 55ac0a70f3SAlex Zinenko throw; 56ac0a70f3SAlex Zinenko } 57ac0a70f3SAlex Zinenko } 58ac0a70f3SAlex Zinenko break; 59ac0a70f3SAlex Zinenko } 60ac0a70f3SAlex Zinenko 61ac0a70f3SAlex Zinenko // Note: Iterator cannot be shared from prior to loading, since re-entrancy 62ac0a70f3SAlex Zinenko // may have occurred, which may do anything. 63ac0a70f3SAlex Zinenko loadedDialectModulesCache.insert(dialectNamespace); 64ac0a70f3SAlex Zinenko } 65ac0a70f3SAlex Zinenko 66ac0a70f3SAlex Zinenko void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 67ac0a70f3SAlex Zinenko py::object pyClass) { 68ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 69ac0a70f3SAlex Zinenko py::object &found = dialectClassMap[dialectNamespace]; 70ac0a70f3SAlex Zinenko if (found) { 71ac0a70f3SAlex Zinenko throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 72ac0a70f3SAlex Zinenko dialectNamespace + 73ac0a70f3SAlex Zinenko "' is already registered."); 74ac0a70f3SAlex Zinenko } 75ac0a70f3SAlex Zinenko found = std::move(pyClass); 76ac0a70f3SAlex Zinenko } 77ac0a70f3SAlex Zinenko 78ac0a70f3SAlex Zinenko void PyGlobals::registerOperationImpl(const std::string &operationName, 79ac0a70f3SAlex Zinenko py::object pyClass, 80ac0a70f3SAlex Zinenko py::object rawOpViewClass) { 81ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 82ac0a70f3SAlex Zinenko py::object &found = operationClassMap[operationName]; 83ac0a70f3SAlex Zinenko if (found) { 84ac0a70f3SAlex Zinenko throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 85ac0a70f3SAlex Zinenko operationName + 86ac0a70f3SAlex Zinenko "' is already registered."); 87ac0a70f3SAlex Zinenko } 88ac0a70f3SAlex Zinenko found = std::move(pyClass); 89ac0a70f3SAlex Zinenko rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 90ac0a70f3SAlex Zinenko } 91ac0a70f3SAlex Zinenko 92ac0a70f3SAlex Zinenko llvm::Optional<py::object> 93ac0a70f3SAlex Zinenko PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 94ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 95ac0a70f3SAlex Zinenko loadDialectModule(dialectNamespace); 96ac0a70f3SAlex Zinenko // Fast match against the class map first (common case). 97ac0a70f3SAlex Zinenko const auto foundIt = dialectClassMap.find(dialectNamespace); 98ac0a70f3SAlex Zinenko if (foundIt != dialectClassMap.end()) { 99ac0a70f3SAlex Zinenko if (foundIt->second.is_none()) 100ac0a70f3SAlex Zinenko return llvm::None; 101ac0a70f3SAlex Zinenko assert(foundIt->second && "py::object is defined"); 102ac0a70f3SAlex Zinenko return foundIt->second; 103ac0a70f3SAlex Zinenko } 104ac0a70f3SAlex Zinenko 105ac0a70f3SAlex Zinenko // Not found and loading did not yield a registration. Negative cache. 106ac0a70f3SAlex Zinenko dialectClassMap[dialectNamespace] = py::none(); 107ac0a70f3SAlex Zinenko return llvm::None; 108ac0a70f3SAlex Zinenko } 109ac0a70f3SAlex Zinenko 110ac0a70f3SAlex Zinenko llvm::Optional<pybind11::object> 111ac0a70f3SAlex Zinenko PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 112ac0a70f3SAlex Zinenko { 113ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 114ac0a70f3SAlex Zinenko auto foundIt = rawOpViewClassMapCache.find(operationName); 115ac0a70f3SAlex Zinenko if (foundIt != rawOpViewClassMapCache.end()) { 116ac0a70f3SAlex Zinenko if (foundIt->second.is_none()) 117ac0a70f3SAlex Zinenko return llvm::None; 118ac0a70f3SAlex Zinenko assert(foundIt->second && "py::object is defined"); 119ac0a70f3SAlex Zinenko return foundIt->second; 120ac0a70f3SAlex Zinenko } 121ac0a70f3SAlex Zinenko } 122ac0a70f3SAlex Zinenko 123ac0a70f3SAlex Zinenko // Not found. Load the dialect namespace. 124ac0a70f3SAlex Zinenko auto split = operationName.split('.'); 125ac0a70f3SAlex Zinenko llvm::StringRef dialectNamespace = split.first; 126ac0a70f3SAlex Zinenko loadDialectModule(dialectNamespace); 127ac0a70f3SAlex Zinenko 128ac0a70f3SAlex Zinenko // Attempt to find from the canonical map and cache. 129ac0a70f3SAlex Zinenko { 130ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 131ac0a70f3SAlex Zinenko auto foundIt = rawOpViewClassMap.find(operationName); 132ac0a70f3SAlex Zinenko if (foundIt != rawOpViewClassMap.end()) { 133ac0a70f3SAlex Zinenko if (foundIt->second.is_none()) 134ac0a70f3SAlex Zinenko return llvm::None; 135ac0a70f3SAlex Zinenko assert(foundIt->second && "py::object is defined"); 136ac0a70f3SAlex Zinenko // Positive cache. 137ac0a70f3SAlex Zinenko rawOpViewClassMapCache[operationName] = foundIt->second; 138ac0a70f3SAlex Zinenko return foundIt->second; 139ac0a70f3SAlex Zinenko } else { 140ac0a70f3SAlex Zinenko // Negative cache. 141ac0a70f3SAlex Zinenko rawOpViewClassMap[operationName] = py::none(); 142ac0a70f3SAlex Zinenko return llvm::None; 143ac0a70f3SAlex Zinenko } 144ac0a70f3SAlex Zinenko } 145ac0a70f3SAlex Zinenko } 146ac0a70f3SAlex Zinenko 147ac0a70f3SAlex Zinenko void PyGlobals::clearImportCache() { 148ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 149ac0a70f3SAlex Zinenko loadedDialectModulesCache.clear(); 150ac0a70f3SAlex Zinenko rawOpViewClassMapCache.clear(); 151ac0a70f3SAlex Zinenko } 152