1722475a3SStella Laurenzo //===- MainModule.cpp - Main pybind module --------------------------------===//
2722475a3SStella Laurenzo //
3722475a3SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4722475a3SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5722475a3SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6722475a3SStella Laurenzo //
7722475a3SStella Laurenzo //===----------------------------------------------------------------------===//
8722475a3SStella Laurenzo 
9722475a3SStella Laurenzo #include <tuple>
10722475a3SStella Laurenzo 
11*013b9322SStella Laurenzo #include "PybindUtils.h"
12722475a3SStella Laurenzo 
13*013b9322SStella Laurenzo #include "Globals.h"
14fcd2969dSzhanghb97 #include "IRModules.h"
15722475a3SStella Laurenzo 
1695b77f2eSStella Laurenzo namespace py = pybind11;
17722475a3SStella Laurenzo using namespace mlir;
1895b77f2eSStella Laurenzo using namespace mlir::python;
19722475a3SStella Laurenzo 
20*013b9322SStella Laurenzo // -----------------------------------------------------------------------------
21*013b9322SStella Laurenzo // PyGlobals
22*013b9322SStella Laurenzo // -----------------------------------------------------------------------------
23*013b9322SStella Laurenzo 
24*013b9322SStella Laurenzo PyGlobals *PyGlobals::instance = nullptr;
25*013b9322SStella Laurenzo 
26*013b9322SStella Laurenzo PyGlobals::PyGlobals() {
27*013b9322SStella Laurenzo   assert(!instance && "PyGlobals already constructed");
28*013b9322SStella Laurenzo   instance = this;
29*013b9322SStella Laurenzo }
30*013b9322SStella Laurenzo 
31*013b9322SStella Laurenzo PyGlobals::~PyGlobals() { instance = nullptr; }
32*013b9322SStella Laurenzo 
33*013b9322SStella Laurenzo void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
34*013b9322SStella Laurenzo   if (loadedDialectModules.contains(dialectNamespace))
35*013b9322SStella Laurenzo     return;
36*013b9322SStella Laurenzo   // Since re-entrancy is possible, make a copy of the search prefixes.
37*013b9322SStella Laurenzo   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
38*013b9322SStella Laurenzo   py::object loaded;
39*013b9322SStella Laurenzo   for (std::string moduleName : localSearchPrefixes) {
40*013b9322SStella Laurenzo     moduleName.push_back('.');
41*013b9322SStella Laurenzo     moduleName.append(dialectNamespace);
42*013b9322SStella Laurenzo 
43*013b9322SStella Laurenzo     try {
44*013b9322SStella Laurenzo       loaded = py::module::import(moduleName.c_str());
45*013b9322SStella Laurenzo     } catch (py::error_already_set &e) {
46*013b9322SStella Laurenzo       if (e.matches(PyExc_ModuleNotFoundError)) {
47*013b9322SStella Laurenzo         continue;
48*013b9322SStella Laurenzo       } else {
49*013b9322SStella Laurenzo         throw;
50*013b9322SStella Laurenzo       }
51*013b9322SStella Laurenzo     }
52*013b9322SStella Laurenzo     break;
53*013b9322SStella Laurenzo   }
54*013b9322SStella Laurenzo 
55*013b9322SStella Laurenzo   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
56*013b9322SStella Laurenzo   // may have occurred, which may do anything.
57*013b9322SStella Laurenzo   loadedDialectModules.insert(dialectNamespace);
58*013b9322SStella Laurenzo }
59*013b9322SStella Laurenzo 
60*013b9322SStella Laurenzo void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
61*013b9322SStella Laurenzo                                     py::object pyClass) {
62*013b9322SStella Laurenzo   py::object &found = dialectClassMap[dialectNamespace];
63*013b9322SStella Laurenzo   if (found) {
64*013b9322SStella Laurenzo     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
65*013b9322SStella Laurenzo                                              dialectNamespace +
66*013b9322SStella Laurenzo                                              "' is already registered.");
67*013b9322SStella Laurenzo   }
68*013b9322SStella Laurenzo   found = std::move(pyClass);
69*013b9322SStella Laurenzo }
70*013b9322SStella Laurenzo 
71*013b9322SStella Laurenzo void PyGlobals::registerOperationImpl(const std::string &operationName,
72*013b9322SStella Laurenzo                                       py::object pyClass, py::object rawClass) {
73*013b9322SStella Laurenzo   py::object &found = operationClassMap[operationName];
74*013b9322SStella Laurenzo   if (found) {
75*013b9322SStella Laurenzo     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
76*013b9322SStella Laurenzo                                              operationName +
77*013b9322SStella Laurenzo                                              "' is already registered.");
78*013b9322SStella Laurenzo   }
79*013b9322SStella Laurenzo   found = std::move(pyClass);
80*013b9322SStella Laurenzo   rawOperationClassMap[operationName] = std::move(rawClass);
81*013b9322SStella Laurenzo }
82*013b9322SStella Laurenzo 
83*013b9322SStella Laurenzo llvm::Optional<py::object>
84*013b9322SStella Laurenzo PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
85*013b9322SStella Laurenzo   loadDialectModule(dialectNamespace);
86*013b9322SStella Laurenzo   // Fast match against the class map first (common case).
87*013b9322SStella Laurenzo   const auto foundIt = dialectClassMap.find(dialectNamespace);
88*013b9322SStella Laurenzo   if (foundIt != dialectClassMap.end()) {
89*013b9322SStella Laurenzo     if (foundIt->second.is_none())
90*013b9322SStella Laurenzo       return llvm::None;
91*013b9322SStella Laurenzo     assert(foundIt->second && "py::object is defined");
92*013b9322SStella Laurenzo     return foundIt->second;
93*013b9322SStella Laurenzo   }
94*013b9322SStella Laurenzo 
95*013b9322SStella Laurenzo   // Not found and loading did not yield a registration. Negative cache.
96*013b9322SStella Laurenzo   dialectClassMap[dialectNamespace] = py::none();
97*013b9322SStella Laurenzo   return llvm::None;
98*013b9322SStella Laurenzo }
99*013b9322SStella Laurenzo 
100*013b9322SStella Laurenzo // -----------------------------------------------------------------------------
101*013b9322SStella Laurenzo // Module initialization.
102*013b9322SStella Laurenzo // -----------------------------------------------------------------------------
103*013b9322SStella Laurenzo 
104722475a3SStella Laurenzo PYBIND11_MODULE(_mlir, m) {
105722475a3SStella Laurenzo   m.doc() = "MLIR Python Native Extension";
106722475a3SStella Laurenzo 
107*013b9322SStella Laurenzo   py::class_<PyGlobals>(m, "_Globals")
108*013b9322SStella Laurenzo       .def_property("dialect_search_modules",
109*013b9322SStella Laurenzo                     &PyGlobals::getDialectSearchPrefixes,
110*013b9322SStella Laurenzo                     &PyGlobals::setDialectSearchPrefixes)
111*013b9322SStella Laurenzo       .def("append_dialect_search_prefix",
112*013b9322SStella Laurenzo            [](PyGlobals &self, std::string moduleName) {
113*013b9322SStella Laurenzo              self.getDialectSearchPrefixes().push_back(std::move(moduleName));
114*013b9322SStella Laurenzo            })
115*013b9322SStella Laurenzo       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
116*013b9322SStella Laurenzo            "Testing hook for directly registering a dialect")
117*013b9322SStella Laurenzo       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
118*013b9322SStella Laurenzo            "Testing hook for directly registering an operation");
119*013b9322SStella Laurenzo 
120*013b9322SStella Laurenzo   // Aside from making the globals accessible to python, having python manage
121*013b9322SStella Laurenzo   // it is necessary to make sure it is destroyed (and releases its python
122*013b9322SStella Laurenzo   // resources) properly.
123*013b9322SStella Laurenzo   m.attr("globals") =
124*013b9322SStella Laurenzo       py::cast(new PyGlobals, py::return_value_policy::take_ownership);
125*013b9322SStella Laurenzo 
126*013b9322SStella Laurenzo   // Registration decorators.
127*013b9322SStella Laurenzo   m.def(
128*013b9322SStella Laurenzo       "register_dialect",
129*013b9322SStella Laurenzo       [](py::object pyClass) {
130*013b9322SStella Laurenzo         std::string dialectNamespace =
131*013b9322SStella Laurenzo             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
132*013b9322SStella Laurenzo         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
133*013b9322SStella Laurenzo         return pyClass;
134*013b9322SStella Laurenzo       },
135*013b9322SStella Laurenzo       "Class decorator for registering a custom Dialect wrapper");
136*013b9322SStella Laurenzo   m.def(
137*013b9322SStella Laurenzo       "register_operation",
138*013b9322SStella Laurenzo       [](py::object dialectClass) -> py::cpp_function {
139*013b9322SStella Laurenzo         return py::cpp_function(
140*013b9322SStella Laurenzo             [dialectClass](py::object opClass) -> py::object {
141*013b9322SStella Laurenzo               std::string operationName =
142*013b9322SStella Laurenzo                   opClass.attr("OPERATION_NAME").cast<std::string>();
143*013b9322SStella Laurenzo               auto rawSubclass = PyOpView::createRawSubclass(opClass);
144*013b9322SStella Laurenzo               PyGlobals::get().registerOperationImpl(operationName, opClass,
145*013b9322SStella Laurenzo                                                      rawSubclass);
146*013b9322SStella Laurenzo 
147*013b9322SStella Laurenzo               // Dict-stuff the new opClass by name onto the dialect class.
148*013b9322SStella Laurenzo               py::object opClassName = opClass.attr("__name__");
149*013b9322SStella Laurenzo               dialectClass.attr(opClassName) = opClass;
150*013b9322SStella Laurenzo 
151*013b9322SStella Laurenzo               // Now create a special "Raw" subclass that passes through
152*013b9322SStella Laurenzo               // construction to the OpView parent (bypasses the intermediate
153*013b9322SStella Laurenzo               // child's __init__).
154*013b9322SStella Laurenzo               opClass.attr("_Raw") = rawSubclass;
155*013b9322SStella Laurenzo               return opClass;
156*013b9322SStella Laurenzo             });
157*013b9322SStella Laurenzo       },
158*013b9322SStella Laurenzo       "Class decorator for registering a custom Operation wrapper");
159*013b9322SStella Laurenzo 
160fcd2969dSzhanghb97   // Define and populate IR submodule.
161fcd2969dSzhanghb97   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
162fcd2969dSzhanghb97   populateIRSubmodule(irModule);
163722475a3SStella Laurenzo }
164