1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <tuple>
10 
11 #include "PybindUtils.h"
12 
13 #include "Globals.h"
14 #include "IRModules.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 // -----------------------------------------------------------------------------
21 // PyGlobals
22 // -----------------------------------------------------------------------------
23 
24 PyGlobals *PyGlobals::instance = nullptr;
25 
26 PyGlobals::PyGlobals() {
27   assert(!instance && "PyGlobals already constructed");
28   instance = this;
29 }
30 
31 PyGlobals::~PyGlobals() { instance = nullptr; }
32 
33 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
34   py::gil_scoped_acquire();
35   if (loadedDialectModulesCache.contains(dialectNamespace))
36     return;
37   // Since re-entrancy is possible, make a copy of the search prefixes.
38   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
39   py::object loaded;
40   for (std::string moduleName : localSearchPrefixes) {
41     moduleName.push_back('.');
42     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
43 
44     try {
45       py::gil_scoped_release();
46       loaded = py::module::import(moduleName.c_str());
47     } catch (py::error_already_set &e) {
48       if (e.matches(PyExc_ModuleNotFoundError)) {
49         continue;
50       } else {
51         throw;
52       }
53     }
54     break;
55   }
56 
57   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
58   // may have occurred, which may do anything.
59   loadedDialectModulesCache.insert(dialectNamespace);
60 }
61 
62 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
63                                     py::object pyClass) {
64   py::gil_scoped_acquire();
65   py::object &found = dialectClassMap[dialectNamespace];
66   if (found) {
67     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
68                                              dialectNamespace +
69                                              "' is already registered.");
70   }
71   found = std::move(pyClass);
72 }
73 
74 void PyGlobals::registerOperationImpl(const std::string &operationName,
75                                       py::object pyClass,
76                                       py::object rawOpViewClass) {
77   py::gil_scoped_acquire();
78   py::object &found = operationClassMap[operationName];
79   if (found) {
80     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
81                                              operationName +
82                                              "' is already registered.");
83   }
84   found = std::move(pyClass);
85   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
86 }
87 
88 llvm::Optional<py::object>
89 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
90   py::gil_scoped_acquire();
91   loadDialectModule(dialectNamespace);
92   // Fast match against the class map first (common case).
93   const auto foundIt = dialectClassMap.find(dialectNamespace);
94   if (foundIt != dialectClassMap.end()) {
95     if (foundIt->second.is_none())
96       return llvm::None;
97     assert(foundIt->second && "py::object is defined");
98     return foundIt->second;
99   }
100 
101   // Not found and loading did not yield a registration. Negative cache.
102   dialectClassMap[dialectNamespace] = py::none();
103   return llvm::None;
104 }
105 
106 llvm::Optional<pybind11::object>
107 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
108   {
109     py::gil_scoped_acquire();
110     auto foundIt = rawOpViewClassMapCache.find(operationName);
111     if (foundIt != rawOpViewClassMapCache.end()) {
112       if (foundIt->second.is_none())
113         return llvm::None;
114       assert(foundIt->second && "py::object is defined");
115       return foundIt->second;
116     }
117   }
118 
119   // Not found. Load the dialect namespace.
120   auto split = operationName.split('.');
121   llvm::StringRef dialectNamespace = split.first;
122   loadDialectModule(dialectNamespace);
123 
124   // Attempt to find from the canonical map and cache.
125   {
126     py::gil_scoped_acquire();
127     auto foundIt = rawOpViewClassMap.find(operationName);
128     if (foundIt != rawOpViewClassMap.end()) {
129       if (foundIt->second.is_none())
130         return llvm::None;
131       assert(foundIt->second && "py::object is defined");
132       // Positive cache.
133       rawOpViewClassMapCache[operationName] = foundIt->second;
134       return foundIt->second;
135     } else {
136       // Negative cache.
137       rawOpViewClassMap[operationName] = py::none();
138       return llvm::None;
139     }
140   }
141 }
142 
143 void PyGlobals::clearImportCache() {
144   py::gil_scoped_acquire();
145   loadedDialectModulesCache.clear();
146   rawOpViewClassMapCache.clear();
147 }
148 
149 // -----------------------------------------------------------------------------
150 // Module initialization.
151 // -----------------------------------------------------------------------------
152 
153 PYBIND11_MODULE(_mlir, m) {
154   m.doc() = "MLIR Python Native Extension";
155 
156   py::class_<PyGlobals>(m, "_Globals")
157       .def_property("dialect_search_modules",
158                     &PyGlobals::getDialectSearchPrefixes,
159                     &PyGlobals::setDialectSearchPrefixes)
160       .def("append_dialect_search_prefix",
161            [](PyGlobals &self, std::string moduleName) {
162              self.getDialectSearchPrefixes().push_back(std::move(moduleName));
163              self.clearImportCache();
164            })
165       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
166            "Testing hook for directly registering a dialect")
167       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
168            "Testing hook for directly registering an operation");
169 
170   // Aside from making the globals accessible to python, having python manage
171   // it is necessary to make sure it is destroyed (and releases its python
172   // resources) properly.
173   m.attr("globals") =
174       py::cast(new PyGlobals, py::return_value_policy::take_ownership);
175 
176   // Registration decorators.
177   m.def(
178       "register_dialect",
179       [](py::object pyClass) {
180         std::string dialectNamespace =
181             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
182         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
183         return pyClass;
184       },
185       "Class decorator for registering a custom Dialect wrapper");
186   m.def(
187       "register_operation",
188       [](py::object dialectClass) -> py::cpp_function {
189         return py::cpp_function(
190             [dialectClass](py::object opClass) -> py::object {
191               std::string operationName =
192                   opClass.attr("OPERATION_NAME").cast<std::string>();
193               auto rawSubclass = PyOpView::createRawSubclass(opClass);
194               PyGlobals::get().registerOperationImpl(operationName, opClass,
195                                                      rawSubclass);
196 
197               // Dict-stuff the new opClass by name onto the dialect class.
198               py::object opClassName = opClass.attr("__name__");
199               dialectClass.attr(opClassName) = opClass;
200 
201               // Now create a special "Raw" subclass that passes through
202               // construction to the OpView parent (bypasses the intermediate
203               // child's __init__).
204               opClass.attr("_Raw") = rawSubclass;
205               return opClass;
206             });
207       },
208       "Class decorator for registering a custom Operation wrapper");
209 
210   // Define and populate IR submodule.
211   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
212   populateIRSubmodule(irModule);
213 }
214