1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IRModule.h"
10 #include "Globals.h"
11 #include "PybindUtils.h"
12
13 #include <vector>
14
15 #include "mlir-c/Bindings/Python/Interop.h"
16
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20
21 // -----------------------------------------------------------------------------
22 // PyGlobals
23 // -----------------------------------------------------------------------------
24
25 PyGlobals *PyGlobals::instance = nullptr;
26
PyGlobals()27 PyGlobals::PyGlobals() {
28 assert(!instance && "PyGlobals already constructed");
29 instance = this;
30 // The default search path include {mlir.}dialects, where {mlir.} is the
31 // package prefix configured at compile time.
32 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
33 }
34
~PyGlobals()35 PyGlobals::~PyGlobals() { instance = nullptr; }
36
loadDialectModule(llvm::StringRef dialectNamespace)37 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
38 if (loadedDialectModulesCache.contains(dialectNamespace))
39 return;
40 // Since re-entrancy is possible, make a copy of the search prefixes.
41 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
42 py::object loaded;
43 for (std::string moduleName : localSearchPrefixes) {
44 moduleName.push_back('.');
45 moduleName.append(dialectNamespace.data(), dialectNamespace.size());
46
47 try {
48 loaded = py::module::import(moduleName.c_str());
49 } catch (py::error_already_set &e) {
50 if (e.matches(PyExc_ModuleNotFoundError)) {
51 continue;
52 }
53 throw;
54 }
55 break;
56 }
57
58 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
59 // may have occurred, which may do anything.
60 loadedDialectModulesCache.insert(dialectNamespace);
61 }
62
registerDialectImpl(const std::string & dialectNamespace,py::object pyClass)63 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
64 py::object pyClass) {
65 py::object &found = dialectClassMap[dialectNamespace];
66 if (found) {
67 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
68 dialectNamespace +
69 "' is already registered.");
70 }
71 found = std::move(pyClass);
72 }
73
registerOperationImpl(const std::string & operationName,py::object pyClass,py::object rawOpViewClass)74 void PyGlobals::registerOperationImpl(const std::string &operationName,
75 py::object pyClass,
76 py::object rawOpViewClass) {
77 py::object &found = operationClassMap[operationName];
78 if (found) {
79 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
80 operationName +
81 "' is already registered.");
82 }
83 found = std::move(pyClass);
84 rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
85 }
86
87 llvm::Optional<py::object>
lookupDialectClass(const std::string & dialectNamespace)88 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
89 loadDialectModule(dialectNamespace);
90 // Fast match against the class map first (common case).
91 const auto foundIt = dialectClassMap.find(dialectNamespace);
92 if (foundIt != dialectClassMap.end()) {
93 if (foundIt->second.is_none())
94 return llvm::None;
95 assert(foundIt->second && "py::object is defined");
96 return foundIt->second;
97 }
98
99 // Not found and loading did not yield a registration. Negative cache.
100 dialectClassMap[dialectNamespace] = py::none();
101 return llvm::None;
102 }
103
104 llvm::Optional<pybind11::object>
lookupRawOpViewClass(llvm::StringRef operationName)105 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
106 {
107 auto foundIt = rawOpViewClassMapCache.find(operationName);
108 if (foundIt != rawOpViewClassMapCache.end()) {
109 if (foundIt->second.is_none())
110 return llvm::None;
111 assert(foundIt->second && "py::object is defined");
112 return foundIt->second;
113 }
114 }
115
116 // Not found. Load the dialect namespace.
117 auto split = operationName.split('.');
118 llvm::StringRef dialectNamespace = split.first;
119 loadDialectModule(dialectNamespace);
120
121 // Attempt to find from the canonical map and cache.
122 {
123 auto foundIt = rawOpViewClassMap.find(operationName);
124 if (foundIt != rawOpViewClassMap.end()) {
125 if (foundIt->second.is_none())
126 return llvm::None;
127 assert(foundIt->second && "py::object is defined");
128 // Positive cache.
129 rawOpViewClassMapCache[operationName] = foundIt->second;
130 return foundIt->second;
131 }
132 // Negative cache.
133 rawOpViewClassMap[operationName] = py::none();
134 return llvm::None;
135 }
136 }
137
clearImportCache()138 void PyGlobals::clearImportCache() {
139 loadedDialectModulesCache.clear();
140 rawOpViewClassMapCache.clear();
141 }
142