1722475a3SStella Laurenzo //===- MainModule.cpp - Main pybind module --------------------------------===//
2722475a3SStella Laurenzo //
3722475a3SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4722475a3SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5722475a3SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6722475a3SStella Laurenzo //
7722475a3SStella Laurenzo //===----------------------------------------------------------------------===//
8722475a3SStella Laurenzo 
9722475a3SStella Laurenzo #include <tuple>
10722475a3SStella Laurenzo 
11013b9322SStella Laurenzo #include "PybindUtils.h"
12722475a3SStella Laurenzo 
13013b9322SStella Laurenzo #include "Globals.h"
14fcd2969dSzhanghb97 #include "IRModules.h"
15722475a3SStella Laurenzo 
1695b77f2eSStella Laurenzo namespace py = pybind11;
17722475a3SStella Laurenzo using namespace mlir;
1895b77f2eSStella Laurenzo using namespace mlir::python;
19722475a3SStella Laurenzo 
20013b9322SStella Laurenzo // -----------------------------------------------------------------------------
21013b9322SStella Laurenzo // PyGlobals
22013b9322SStella Laurenzo // -----------------------------------------------------------------------------
23013b9322SStella Laurenzo 
24013b9322SStella Laurenzo PyGlobals *PyGlobals::instance = nullptr;
25013b9322SStella Laurenzo 
26013b9322SStella Laurenzo PyGlobals::PyGlobals() {
27013b9322SStella Laurenzo   assert(!instance && "PyGlobals already constructed");
28013b9322SStella Laurenzo   instance = this;
29013b9322SStella Laurenzo }
30013b9322SStella Laurenzo 
31013b9322SStella Laurenzo PyGlobals::~PyGlobals() { instance = nullptr; }
32013b9322SStella Laurenzo 
33*8260db75SStella Laurenzo void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
34*8260db75SStella Laurenzo   py::gil_scoped_acquire();
35*8260db75SStella Laurenzo   if (loadedDialectModulesCache.contains(dialectNamespace))
36013b9322SStella Laurenzo     return;
37013b9322SStella Laurenzo   // Since re-entrancy is possible, make a copy of the search prefixes.
38013b9322SStella Laurenzo   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
39013b9322SStella Laurenzo   py::object loaded;
40013b9322SStella Laurenzo   for (std::string moduleName : localSearchPrefixes) {
41013b9322SStella Laurenzo     moduleName.push_back('.');
42*8260db75SStella Laurenzo     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
43013b9322SStella Laurenzo 
44013b9322SStella Laurenzo     try {
45*8260db75SStella Laurenzo       py::gil_scoped_release();
46013b9322SStella Laurenzo       loaded = py::module::import(moduleName.c_str());
47013b9322SStella Laurenzo     } catch (py::error_already_set &e) {
48013b9322SStella Laurenzo       if (e.matches(PyExc_ModuleNotFoundError)) {
49013b9322SStella Laurenzo         continue;
50013b9322SStella Laurenzo       } else {
51013b9322SStella Laurenzo         throw;
52013b9322SStella Laurenzo       }
53013b9322SStella Laurenzo     }
54013b9322SStella Laurenzo     break;
55013b9322SStella Laurenzo   }
56013b9322SStella Laurenzo 
57013b9322SStella Laurenzo   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
58013b9322SStella Laurenzo   // may have occurred, which may do anything.
59*8260db75SStella Laurenzo   loadedDialectModulesCache.insert(dialectNamespace);
60013b9322SStella Laurenzo }
61013b9322SStella Laurenzo 
62013b9322SStella Laurenzo void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
63013b9322SStella Laurenzo                                     py::object pyClass) {
64*8260db75SStella Laurenzo   py::gil_scoped_acquire();
65013b9322SStella Laurenzo   py::object &found = dialectClassMap[dialectNamespace];
66013b9322SStella Laurenzo   if (found) {
67013b9322SStella Laurenzo     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
68013b9322SStella Laurenzo                                              dialectNamespace +
69013b9322SStella Laurenzo                                              "' is already registered.");
70013b9322SStella Laurenzo   }
71013b9322SStella Laurenzo   found = std::move(pyClass);
72013b9322SStella Laurenzo }
73013b9322SStella Laurenzo 
74013b9322SStella Laurenzo void PyGlobals::registerOperationImpl(const std::string &operationName,
75*8260db75SStella Laurenzo                                       py::object pyClass,
76*8260db75SStella Laurenzo                                       py::object rawOpViewClass) {
77*8260db75SStella Laurenzo   py::gil_scoped_acquire();
78013b9322SStella Laurenzo   py::object &found = operationClassMap[operationName];
79013b9322SStella Laurenzo   if (found) {
80013b9322SStella Laurenzo     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
81013b9322SStella Laurenzo                                              operationName +
82013b9322SStella Laurenzo                                              "' is already registered.");
83013b9322SStella Laurenzo   }
84013b9322SStella Laurenzo   found = std::move(pyClass);
85*8260db75SStella Laurenzo   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
86013b9322SStella Laurenzo }
87013b9322SStella Laurenzo 
88013b9322SStella Laurenzo llvm::Optional<py::object>
89013b9322SStella Laurenzo PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
90*8260db75SStella Laurenzo   py::gil_scoped_acquire();
91013b9322SStella Laurenzo   loadDialectModule(dialectNamespace);
92013b9322SStella Laurenzo   // Fast match against the class map first (common case).
93013b9322SStella Laurenzo   const auto foundIt = dialectClassMap.find(dialectNamespace);
94013b9322SStella Laurenzo   if (foundIt != dialectClassMap.end()) {
95013b9322SStella Laurenzo     if (foundIt->second.is_none())
96013b9322SStella Laurenzo       return llvm::None;
97013b9322SStella Laurenzo     assert(foundIt->second && "py::object is defined");
98013b9322SStella Laurenzo     return foundIt->second;
99013b9322SStella Laurenzo   }
100013b9322SStella Laurenzo 
101013b9322SStella Laurenzo   // Not found and loading did not yield a registration. Negative cache.
102013b9322SStella Laurenzo   dialectClassMap[dialectNamespace] = py::none();
103013b9322SStella Laurenzo   return llvm::None;
104013b9322SStella Laurenzo }
105013b9322SStella Laurenzo 
106*8260db75SStella Laurenzo llvm::Optional<pybind11::object>
107*8260db75SStella Laurenzo PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
108*8260db75SStella Laurenzo   {
109*8260db75SStella Laurenzo     py::gil_scoped_acquire();
110*8260db75SStella Laurenzo     auto foundIt = rawOpViewClassMapCache.find(operationName);
111*8260db75SStella Laurenzo     if (foundIt != rawOpViewClassMapCache.end()) {
112*8260db75SStella Laurenzo       if (foundIt->second.is_none())
113*8260db75SStella Laurenzo         return llvm::None;
114*8260db75SStella Laurenzo       assert(foundIt->second && "py::object is defined");
115*8260db75SStella Laurenzo       return foundIt->second;
116*8260db75SStella Laurenzo     }
117*8260db75SStella Laurenzo   }
118*8260db75SStella Laurenzo 
119*8260db75SStella Laurenzo   // Not found. Load the dialect namespace.
120*8260db75SStella Laurenzo   auto split = operationName.split('.');
121*8260db75SStella Laurenzo   llvm::StringRef dialectNamespace = split.first;
122*8260db75SStella Laurenzo   loadDialectModule(dialectNamespace);
123*8260db75SStella Laurenzo 
124*8260db75SStella Laurenzo   // Attempt to find from the canonical map and cache.
125*8260db75SStella Laurenzo   {
126*8260db75SStella Laurenzo     py::gil_scoped_acquire();
127*8260db75SStella Laurenzo     auto foundIt = rawOpViewClassMap.find(operationName);
128*8260db75SStella Laurenzo     if (foundIt != rawOpViewClassMap.end()) {
129*8260db75SStella Laurenzo       if (foundIt->second.is_none())
130*8260db75SStella Laurenzo         return llvm::None;
131*8260db75SStella Laurenzo       assert(foundIt->second && "py::object is defined");
132*8260db75SStella Laurenzo       // Positive cache.
133*8260db75SStella Laurenzo       rawOpViewClassMapCache[operationName] = foundIt->second;
134*8260db75SStella Laurenzo       return foundIt->second;
135*8260db75SStella Laurenzo     } else {
136*8260db75SStella Laurenzo       // Negative cache.
137*8260db75SStella Laurenzo       rawOpViewClassMap[operationName] = py::none();
138*8260db75SStella Laurenzo       return llvm::None;
139*8260db75SStella Laurenzo     }
140*8260db75SStella Laurenzo   }
141*8260db75SStella Laurenzo }
142*8260db75SStella Laurenzo 
143*8260db75SStella Laurenzo void PyGlobals::clearImportCache() {
144*8260db75SStella Laurenzo   py::gil_scoped_acquire();
145*8260db75SStella Laurenzo   loadedDialectModulesCache.clear();
146*8260db75SStella Laurenzo   rawOpViewClassMapCache.clear();
147*8260db75SStella Laurenzo }
148*8260db75SStella Laurenzo 
149013b9322SStella Laurenzo // -----------------------------------------------------------------------------
150013b9322SStella Laurenzo // Module initialization.
151013b9322SStella Laurenzo // -----------------------------------------------------------------------------
152013b9322SStella Laurenzo 
153722475a3SStella Laurenzo PYBIND11_MODULE(_mlir, m) {
154722475a3SStella Laurenzo   m.doc() = "MLIR Python Native Extension";
155722475a3SStella Laurenzo 
156013b9322SStella Laurenzo   py::class_<PyGlobals>(m, "_Globals")
157013b9322SStella Laurenzo       .def_property("dialect_search_modules",
158013b9322SStella Laurenzo                     &PyGlobals::getDialectSearchPrefixes,
159013b9322SStella Laurenzo                     &PyGlobals::setDialectSearchPrefixes)
160013b9322SStella Laurenzo       .def("append_dialect_search_prefix",
161013b9322SStella Laurenzo            [](PyGlobals &self, std::string moduleName) {
162013b9322SStella Laurenzo              self.getDialectSearchPrefixes().push_back(std::move(moduleName));
163*8260db75SStella Laurenzo              self.clearImportCache();
164013b9322SStella Laurenzo            })
165013b9322SStella Laurenzo       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
166013b9322SStella Laurenzo            "Testing hook for directly registering a dialect")
167013b9322SStella Laurenzo       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
168013b9322SStella Laurenzo            "Testing hook for directly registering an operation");
169013b9322SStella Laurenzo 
170013b9322SStella Laurenzo   // Aside from making the globals accessible to python, having python manage
171013b9322SStella Laurenzo   // it is necessary to make sure it is destroyed (and releases its python
172013b9322SStella Laurenzo   // resources) properly.
173013b9322SStella Laurenzo   m.attr("globals") =
174013b9322SStella Laurenzo       py::cast(new PyGlobals, py::return_value_policy::take_ownership);
175013b9322SStella Laurenzo 
176013b9322SStella Laurenzo   // Registration decorators.
177013b9322SStella Laurenzo   m.def(
178013b9322SStella Laurenzo       "register_dialect",
179013b9322SStella Laurenzo       [](py::object pyClass) {
180013b9322SStella Laurenzo         std::string dialectNamespace =
181013b9322SStella Laurenzo             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
182013b9322SStella Laurenzo         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
183013b9322SStella Laurenzo         return pyClass;
184013b9322SStella Laurenzo       },
185013b9322SStella Laurenzo       "Class decorator for registering a custom Dialect wrapper");
186013b9322SStella Laurenzo   m.def(
187013b9322SStella Laurenzo       "register_operation",
188013b9322SStella Laurenzo       [](py::object dialectClass) -> py::cpp_function {
189013b9322SStella Laurenzo         return py::cpp_function(
190013b9322SStella Laurenzo             [dialectClass](py::object opClass) -> py::object {
191013b9322SStella Laurenzo               std::string operationName =
192013b9322SStella Laurenzo                   opClass.attr("OPERATION_NAME").cast<std::string>();
193013b9322SStella Laurenzo               auto rawSubclass = PyOpView::createRawSubclass(opClass);
194013b9322SStella Laurenzo               PyGlobals::get().registerOperationImpl(operationName, opClass,
195013b9322SStella Laurenzo                                                      rawSubclass);
196013b9322SStella Laurenzo 
197013b9322SStella Laurenzo               // Dict-stuff the new opClass by name onto the dialect class.
198013b9322SStella Laurenzo               py::object opClassName = opClass.attr("__name__");
199013b9322SStella Laurenzo               dialectClass.attr(opClassName) = opClass;
200013b9322SStella Laurenzo 
201013b9322SStella Laurenzo               // Now create a special "Raw" subclass that passes through
202013b9322SStella Laurenzo               // construction to the OpView parent (bypasses the intermediate
203013b9322SStella Laurenzo               // child's __init__).
204013b9322SStella Laurenzo               opClass.attr("_Raw") = rawSubclass;
205013b9322SStella Laurenzo               return opClass;
206013b9322SStella Laurenzo             });
207013b9322SStella Laurenzo       },
208013b9322SStella Laurenzo       "Class decorator for registering a custom Operation wrapper");
209013b9322SStella Laurenzo 
210fcd2969dSzhanghb97   // Define and populate IR submodule.
211fcd2969dSzhanghb97   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
212fcd2969dSzhanghb97   populateIRSubmodule(irModule);
213722475a3SStella Laurenzo }
214