1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 #include "Globals.h"
11 #include "PybindUtils.h"
12 
13 #include <vector>
14 
15 #include "mlir-c/Bindings/Python/Interop.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 // -----------------------------------------------------------------------------
22 // PyGlobals
23 // -----------------------------------------------------------------------------
24 
25 PyGlobals *PyGlobals::instance = nullptr;
26 
27 PyGlobals::PyGlobals() {
28   assert(!instance && "PyGlobals already constructed");
29   instance = this;
30   // The default search path include {mlir.}dialects, where {mlir.} is the
31   // package prefix configured at compile time.
32   dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
33 }
34 
35 PyGlobals::~PyGlobals() { instance = nullptr; }
36 
37 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
38   py::gil_scoped_acquire();
39   if (loadedDialectModulesCache.contains(dialectNamespace))
40     return;
41   // Since re-entrancy is possible, make a copy of the search prefixes.
42   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
43   py::object loaded;
44   for (std::string moduleName : localSearchPrefixes) {
45     moduleName.push_back('.');
46     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
47 
48     try {
49       py::gil_scoped_release();
50       loaded = py::module::import(moduleName.c_str());
51     } catch (py::error_already_set &e) {
52       if (e.matches(PyExc_ModuleNotFoundError)) {
53         continue;
54       } else {
55         throw;
56       }
57     }
58     break;
59   }
60 
61   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
62   // may have occurred, which may do anything.
63   loadedDialectModulesCache.insert(dialectNamespace);
64 }
65 
66 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
67                                     py::object pyClass) {
68   py::gil_scoped_acquire();
69   py::object &found = dialectClassMap[dialectNamespace];
70   if (found) {
71     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
72                                              dialectNamespace +
73                                              "' is already registered.");
74   }
75   found = std::move(pyClass);
76 }
77 
78 void PyGlobals::registerOperationImpl(const std::string &operationName,
79                                       py::object pyClass,
80                                       py::object rawOpViewClass) {
81   py::gil_scoped_acquire();
82   py::object &found = operationClassMap[operationName];
83   if (found) {
84     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
85                                              operationName +
86                                              "' is already registered.");
87   }
88   found = std::move(pyClass);
89   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
90 }
91 
92 llvm::Optional<py::object>
93 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
94   py::gil_scoped_acquire();
95   loadDialectModule(dialectNamespace);
96   // Fast match against the class map first (common case).
97   const auto foundIt = dialectClassMap.find(dialectNamespace);
98   if (foundIt != dialectClassMap.end()) {
99     if (foundIt->second.is_none())
100       return llvm::None;
101     assert(foundIt->second && "py::object is defined");
102     return foundIt->second;
103   }
104 
105   // Not found and loading did not yield a registration. Negative cache.
106   dialectClassMap[dialectNamespace] = py::none();
107   return llvm::None;
108 }
109 
110 llvm::Optional<pybind11::object>
111 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
112   {
113     py::gil_scoped_acquire();
114     auto foundIt = rawOpViewClassMapCache.find(operationName);
115     if (foundIt != rawOpViewClassMapCache.end()) {
116       if (foundIt->second.is_none())
117         return llvm::None;
118       assert(foundIt->second && "py::object is defined");
119       return foundIt->second;
120     }
121   }
122 
123   // Not found. Load the dialect namespace.
124   auto split = operationName.split('.');
125   llvm::StringRef dialectNamespace = split.first;
126   loadDialectModule(dialectNamespace);
127 
128   // Attempt to find from the canonical map and cache.
129   {
130     py::gil_scoped_acquire();
131     auto foundIt = rawOpViewClassMap.find(operationName);
132     if (foundIt != rawOpViewClassMap.end()) {
133       if (foundIt->second.is_none())
134         return llvm::None;
135       assert(foundIt->second && "py::object is defined");
136       // Positive cache.
137       rawOpViewClassMapCache[operationName] = foundIt->second;
138       return foundIt->second;
139     } else {
140       // Negative cache.
141       rawOpViewClassMap[operationName] = py::none();
142       return llvm::None;
143     }
144   }
145 }
146 
147 void PyGlobals::clearImportCache() {
148   py::gil_scoped_acquire();
149   loadedDialectModulesCache.clear();
150   rawOpViewClassMapCache.clear();
151 }
152