1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <tuple>
10 
11 #include "PybindUtils.h"
12 
13 #include "Globals.h"
14 #include "IRModules.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 // -----------------------------------------------------------------------------
21 // PyGlobals
22 // -----------------------------------------------------------------------------
23 
24 PyGlobals *PyGlobals::instance = nullptr;
25 
26 PyGlobals::PyGlobals() {
27   assert(!instance && "PyGlobals already constructed");
28   instance = this;
29 }
30 
31 PyGlobals::~PyGlobals() { instance = nullptr; }
32 
33 void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
34   if (loadedDialectModules.contains(dialectNamespace))
35     return;
36   // Since re-entrancy is possible, make a copy of the search prefixes.
37   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
38   py::object loaded;
39   for (std::string moduleName : localSearchPrefixes) {
40     moduleName.push_back('.');
41     moduleName.append(dialectNamespace);
42 
43     try {
44       loaded = py::module::import(moduleName.c_str());
45     } catch (py::error_already_set &e) {
46       if (e.matches(PyExc_ModuleNotFoundError)) {
47         continue;
48       } else {
49         throw;
50       }
51     }
52     break;
53   }
54 
55   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
56   // may have occurred, which may do anything.
57   loadedDialectModules.insert(dialectNamespace);
58 }
59 
60 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
61                                     py::object pyClass) {
62   py::object &found = dialectClassMap[dialectNamespace];
63   if (found) {
64     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
65                                              dialectNamespace +
66                                              "' is already registered.");
67   }
68   found = std::move(pyClass);
69 }
70 
71 void PyGlobals::registerOperationImpl(const std::string &operationName,
72                                       py::object pyClass, py::object rawClass) {
73   py::object &found = operationClassMap[operationName];
74   if (found) {
75     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
76                                              operationName +
77                                              "' is already registered.");
78   }
79   found = std::move(pyClass);
80   rawOperationClassMap[operationName] = std::move(rawClass);
81 }
82 
83 llvm::Optional<py::object>
84 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
85   loadDialectModule(dialectNamespace);
86   // Fast match against the class map first (common case).
87   const auto foundIt = dialectClassMap.find(dialectNamespace);
88   if (foundIt != dialectClassMap.end()) {
89     if (foundIt->second.is_none())
90       return llvm::None;
91     assert(foundIt->second && "py::object is defined");
92     return foundIt->second;
93   }
94 
95   // Not found and loading did not yield a registration. Negative cache.
96   dialectClassMap[dialectNamespace] = py::none();
97   return llvm::None;
98 }
99 
100 // -----------------------------------------------------------------------------
101 // Module initialization.
102 // -----------------------------------------------------------------------------
103 
// Entry point of the `_mlir` native extension module. It does four things, in
// this order:
//  1. Registers the `_Globals` class.
//  2. Publishes a single PyGlobals instance as `globals`.
//  3. Defines the `register_dialect` and `register_operation` decorators.
//  4. Populates the `ir` submodule.
PYBIND11_MODULE(_mlir, m) {
  m.doc() = "MLIR Python Native Extension";

  // Bind PyGlobals. This must happen before the py::cast of the instance
  // below, since pybind11 can only wrap instances of registered classes.
  py::class_<PyGlobals>(m, "_Globals")
      .def_property("dialect_search_modules",
                    &PyGlobals::getDialectSearchPrefixes,
                    &PyGlobals::setDialectSearchPrefixes)
      .def("append_dialect_search_prefix",
           [](PyGlobals &self, std::string moduleName) {
             // NOTE(review): relies on getDialectSearchPrefixes() returning a
             // mutable reference; otherwise this push_back has no effect.
             self.getDialectSearchPrefixes().push_back(std::move(moduleName));
           })
      .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
           "Testing hook for directly registering a dialect")
      .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
           "Testing hook for directly registering an operation");

  // Aside from making the globals accessible to python, having python manage
  // it is necessary to make sure it is destroyed (and releases its python
  // resources) properly. take_ownership hands the heap object to the Python
  // wrapper, which deletes it when collected.
  m.attr("globals") =
      py::cast(new PyGlobals, py::return_value_policy::take_ownership);

  // Registration decorators.
  //
  // @register_dialect: the decorated class must have a DIALECT_NAMESPACE
  // attribute. It is registered under that namespace and returned unchanged.
  m.def(
      "register_dialect",
      [](py::object pyClass) {
        std::string dialectNamespace =
            pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
        PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
        return pyClass;
      },
      "Class decorator for registering a custom Dialect wrapper");
  // @register_operation(DialectClass): a decorator factory. Calling it with a
  // dialect class returns the actual class decorator, which is applied to an
  // op class carrying an OPERATION_NAME attribute.
  m.def(
      "register_operation",
      [](py::object dialectClass) -> py::cpp_function {
        return py::cpp_function(
            [dialectClass](py::object opClass) -> py::object {
              std::string operationName =
                  opClass.attr("OPERATION_NAME").cast<std::string>();
              // Create a special "Raw" subclass that passes through
              // construction to the OpView parent (bypasses the intermediate
              // child's __init__), and register both classes under the
              // operation name.
              auto rawSubclass = PyOpView::createRawSubclass(opClass);
              PyGlobals::get().registerOperationImpl(operationName, opClass,
                                                     rawSubclass);

              // Attach the op class to the dialect class as an attribute
              // named by the op class's Python __name__.
              py::object opClassName = opClass.attr("__name__");
              dialectClass.attr(opClassName) = opClass;

              // Expose the raw subclass created above as `_Raw` on the op
              // class, then return the op class unchanged (decorator
              // semantics).
              opClass.attr("_Raw") = rawSubclass;
              return opClass;
            });
      },
      "Class decorator for registering a custom Operation wrapper");

  // Define and populate IR submodule.
  auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
  populateIRSubmodule(irModule);
}
164