1*ac0a70f3SAlex Zinenko //===- IRModule.cpp - IR pybind module ------------------------------------===// 2*ac0a70f3SAlex Zinenko // 3*ac0a70f3SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*ac0a70f3SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5*ac0a70f3SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*ac0a70f3SAlex Zinenko // 7*ac0a70f3SAlex Zinenko //===----------------------------------------------------------------------===// 8*ac0a70f3SAlex Zinenko 9*ac0a70f3SAlex Zinenko #include "IRModule.h" 10*ac0a70f3SAlex Zinenko #include "Globals.h" 11*ac0a70f3SAlex Zinenko #include "PybindUtils.h" 12*ac0a70f3SAlex Zinenko 13*ac0a70f3SAlex Zinenko #include <vector> 14*ac0a70f3SAlex Zinenko 15*ac0a70f3SAlex Zinenko namespace py = pybind11; 16*ac0a70f3SAlex Zinenko using namespace mlir; 17*ac0a70f3SAlex Zinenko using namespace mlir::python; 18*ac0a70f3SAlex Zinenko 19*ac0a70f3SAlex Zinenko // ----------------------------------------------------------------------------- 20*ac0a70f3SAlex Zinenko // PyGlobals 21*ac0a70f3SAlex Zinenko // ----------------------------------------------------------------------------- 22*ac0a70f3SAlex Zinenko 23*ac0a70f3SAlex Zinenko PyGlobals *PyGlobals::instance = nullptr; 24*ac0a70f3SAlex Zinenko 25*ac0a70f3SAlex Zinenko PyGlobals::PyGlobals() { 26*ac0a70f3SAlex Zinenko assert(!instance && "PyGlobals already constructed"); 27*ac0a70f3SAlex Zinenko instance = this; 28*ac0a70f3SAlex Zinenko } 29*ac0a70f3SAlex Zinenko 30*ac0a70f3SAlex Zinenko PyGlobals::~PyGlobals() { instance = nullptr; } 31*ac0a70f3SAlex Zinenko 32*ac0a70f3SAlex Zinenko void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { 33*ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 34*ac0a70f3SAlex Zinenko if (loadedDialectModulesCache.contains(dialectNamespace)) 35*ac0a70f3SAlex Zinenko return; 36*ac0a70f3SAlex Zinenko // Since re-entrancy is possible, make a copy of the search prefixes. 37*ac0a70f3SAlex Zinenko std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; 38*ac0a70f3SAlex Zinenko py::object loaded; 39*ac0a70f3SAlex Zinenko for (std::string moduleName : localSearchPrefixes) { 40*ac0a70f3SAlex Zinenko moduleName.push_back('.'); 41*ac0a70f3SAlex Zinenko moduleName.append(dialectNamespace.data(), dialectNamespace.size()); 42*ac0a70f3SAlex Zinenko 43*ac0a70f3SAlex Zinenko try { 44*ac0a70f3SAlex Zinenko py::gil_scoped_release(); 45*ac0a70f3SAlex Zinenko loaded = py::module::import(moduleName.c_str()); 46*ac0a70f3SAlex Zinenko } catch (py::error_already_set &e) { 47*ac0a70f3SAlex Zinenko if (e.matches(PyExc_ModuleNotFoundError)) { 48*ac0a70f3SAlex Zinenko continue; 49*ac0a70f3SAlex Zinenko } else { 50*ac0a70f3SAlex Zinenko throw; 51*ac0a70f3SAlex Zinenko } 52*ac0a70f3SAlex Zinenko } 53*ac0a70f3SAlex Zinenko break; 54*ac0a70f3SAlex Zinenko } 55*ac0a70f3SAlex Zinenko 56*ac0a70f3SAlex Zinenko // Note: Iterator cannot be shared from prior to loading, since re-entrancy 57*ac0a70f3SAlex Zinenko // may have occurred, which may do anything. 58*ac0a70f3SAlex Zinenko loadedDialectModulesCache.insert(dialectNamespace); 59*ac0a70f3SAlex Zinenko } 60*ac0a70f3SAlex Zinenko 61*ac0a70f3SAlex Zinenko void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, 62*ac0a70f3SAlex Zinenko py::object pyClass) { 63*ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 64*ac0a70f3SAlex Zinenko py::object &found = dialectClassMap[dialectNamespace]; 65*ac0a70f3SAlex Zinenko if (found) { 66*ac0a70f3SAlex Zinenko throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + 67*ac0a70f3SAlex Zinenko dialectNamespace + 68*ac0a70f3SAlex Zinenko "' is already registered."); 69*ac0a70f3SAlex Zinenko } 70*ac0a70f3SAlex Zinenko found = std::move(pyClass); 71*ac0a70f3SAlex Zinenko } 72*ac0a70f3SAlex Zinenko 73*ac0a70f3SAlex Zinenko void PyGlobals::registerOperationImpl(const std::string &operationName, 74*ac0a70f3SAlex Zinenko py::object pyClass, 75*ac0a70f3SAlex Zinenko py::object rawOpViewClass) { 76*ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 77*ac0a70f3SAlex Zinenko py::object &found = operationClassMap[operationName]; 78*ac0a70f3SAlex Zinenko if (found) { 79*ac0a70f3SAlex Zinenko throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + 80*ac0a70f3SAlex Zinenko operationName + 81*ac0a70f3SAlex Zinenko "' is already registered."); 82*ac0a70f3SAlex Zinenko } 83*ac0a70f3SAlex Zinenko found = std::move(pyClass); 84*ac0a70f3SAlex Zinenko rawOpViewClassMap[operationName] = std::move(rawOpViewClass); 85*ac0a70f3SAlex Zinenko } 86*ac0a70f3SAlex Zinenko 87*ac0a70f3SAlex Zinenko llvm::Optional<py::object> 88*ac0a70f3SAlex Zinenko PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { 89*ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 90*ac0a70f3SAlex Zinenko loadDialectModule(dialectNamespace); 91*ac0a70f3SAlex Zinenko // Fast match against the class map first (common case). 92*ac0a70f3SAlex Zinenko const auto foundIt = dialectClassMap.find(dialectNamespace); 93*ac0a70f3SAlex Zinenko if (foundIt != dialectClassMap.end()) { 94*ac0a70f3SAlex Zinenko if (foundIt->second.is_none()) 95*ac0a70f3SAlex Zinenko return llvm::None; 96*ac0a70f3SAlex Zinenko assert(foundIt->second && "py::object is defined"); 97*ac0a70f3SAlex Zinenko return foundIt->second; 98*ac0a70f3SAlex Zinenko } 99*ac0a70f3SAlex Zinenko 100*ac0a70f3SAlex Zinenko // Not found and loading did not yield a registration. Negative cache. 101*ac0a70f3SAlex Zinenko dialectClassMap[dialectNamespace] = py::none(); 102*ac0a70f3SAlex Zinenko return llvm::None; 103*ac0a70f3SAlex Zinenko } 104*ac0a70f3SAlex Zinenko 105*ac0a70f3SAlex Zinenko llvm::Optional<pybind11::object> 106*ac0a70f3SAlex Zinenko PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { 107*ac0a70f3SAlex Zinenko { 108*ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 109*ac0a70f3SAlex Zinenko auto foundIt = rawOpViewClassMapCache.find(operationName); 110*ac0a70f3SAlex Zinenko if (foundIt != rawOpViewClassMapCache.end()) { 111*ac0a70f3SAlex Zinenko if (foundIt->second.is_none()) 112*ac0a70f3SAlex Zinenko return llvm::None; 113*ac0a70f3SAlex Zinenko assert(foundIt->second && "py::object is defined"); 114*ac0a70f3SAlex Zinenko return foundIt->second; 115*ac0a70f3SAlex Zinenko } 116*ac0a70f3SAlex Zinenko } 117*ac0a70f3SAlex Zinenko 118*ac0a70f3SAlex Zinenko // Not found. Load the dialect namespace. 119*ac0a70f3SAlex Zinenko auto split = operationName.split('.'); 120*ac0a70f3SAlex Zinenko llvm::StringRef dialectNamespace = split.first; 121*ac0a70f3SAlex Zinenko loadDialectModule(dialectNamespace); 122*ac0a70f3SAlex Zinenko 123*ac0a70f3SAlex Zinenko // Attempt to find from the canonical map and cache. 124*ac0a70f3SAlex Zinenko { 125*ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 126*ac0a70f3SAlex Zinenko auto foundIt = rawOpViewClassMap.find(operationName); 127*ac0a70f3SAlex Zinenko if (foundIt != rawOpViewClassMap.end()) { 128*ac0a70f3SAlex Zinenko if (foundIt->second.is_none()) 129*ac0a70f3SAlex Zinenko return llvm::None; 130*ac0a70f3SAlex Zinenko assert(foundIt->second && "py::object is defined"); 131*ac0a70f3SAlex Zinenko // Positive cache. 132*ac0a70f3SAlex Zinenko rawOpViewClassMapCache[operationName] = foundIt->second; 133*ac0a70f3SAlex Zinenko return foundIt->second; 134*ac0a70f3SAlex Zinenko } else { 135*ac0a70f3SAlex Zinenko // Negative cache. 136*ac0a70f3SAlex Zinenko rawOpViewClassMap[operationName] = py::none(); 137*ac0a70f3SAlex Zinenko return llvm::None; 138*ac0a70f3SAlex Zinenko } 139*ac0a70f3SAlex Zinenko } 140*ac0a70f3SAlex Zinenko } 141*ac0a70f3SAlex Zinenko 142*ac0a70f3SAlex Zinenko void PyGlobals::clearImportCache() { 143*ac0a70f3SAlex Zinenko py::gil_scoped_acquire(); 144*ac0a70f3SAlex Zinenko loadedDialectModulesCache.clear(); 145*ac0a70f3SAlex Zinenko rawOpViewClassMapCache.clear(); 146*ac0a70f3SAlex Zinenko } 147