1722475a3SStella Laurenzo //===- MainModule.cpp - Main pybind module --------------------------------===// 2722475a3SStella Laurenzo // 3722475a3SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4722475a3SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5722475a3SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6722475a3SStella Laurenzo // 7722475a3SStella Laurenzo //===----------------------------------------------------------------------===// 8722475a3SStella Laurenzo 9722475a3SStella Laurenzo #include <tuple> 10722475a3SStella Laurenzo 11013b9322SStella Laurenzo #include "PybindUtils.h" 12722475a3SStella Laurenzo 13f13893f6SStella Laurenzo #include "Dialects.h" 14013b9322SStella Laurenzo #include "Globals.h" 15436c6c9cSStella Laurenzo #include "IRModule.h" 16dc43f785SMehdi Amini #include "Pass.h" 17722475a3SStella Laurenzo 1895b77f2eSStella Laurenzo namespace py = pybind11; 19722475a3SStella Laurenzo using namespace mlir; 2095b77f2eSStella Laurenzo using namespace mlir::python; 21722475a3SStella Laurenzo 22013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 23013b9322SStella Laurenzo // Module initialization. 24013b9322SStella Laurenzo // ----------------------------------------------------------------------------- 25013b9322SStella Laurenzo 26722475a3SStella Laurenzo PYBIND11_MODULE(_mlir, m) { 27722475a3SStella Laurenzo m.doc() = "MLIR Python Native Extension"; 28722475a3SStella Laurenzo 29f05ff4f7SStella Laurenzo py::class_<PyGlobals>(m, "_Globals", py::module_local()) 30013b9322SStella Laurenzo .def_property("dialect_search_modules", 31013b9322SStella Laurenzo &PyGlobals::getDialectSearchPrefixes, 32013b9322SStella Laurenzo &PyGlobals::setDialectSearchPrefixes) 33*a6e7d024SStella Laurenzo .def( 34*a6e7d024SStella Laurenzo "append_dialect_search_prefix", 35013b9322SStella Laurenzo [](PyGlobals &self, std::string moduleName) { 36013b9322SStella Laurenzo self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 378260db75SStella Laurenzo self.clearImportCache(); 38*a6e7d024SStella Laurenzo }, 39*a6e7d024SStella Laurenzo py::arg("module_name")) 40013b9322SStella Laurenzo .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 41*a6e7d024SStella Laurenzo py::arg("dialect_namespace"), py::arg("dialect_class"), 42013b9322SStella Laurenzo "Testing hook for directly registering a dialect") 43013b9322SStella Laurenzo .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 44*a6e7d024SStella Laurenzo py::arg("operation_name"), py::arg("operation_class"), 45*a6e7d024SStella Laurenzo py::arg("raw_opview_class"), 46013b9322SStella Laurenzo "Testing hook for directly registering an operation"); 47013b9322SStella Laurenzo 48013b9322SStella Laurenzo // Aside from making the globals accessible to python, having python manage 49013b9322SStella Laurenzo // it is necessary to make sure it is destroyed (and releases its python 50013b9322SStella Laurenzo // resources) properly. 51013b9322SStella Laurenzo m.attr("globals") = 52013b9322SStella Laurenzo py::cast(new PyGlobals, py::return_value_policy::take_ownership); 53013b9322SStella Laurenzo 54013b9322SStella Laurenzo // Registration decorators. 55013b9322SStella Laurenzo m.def( 56013b9322SStella Laurenzo "register_dialect", 57013b9322SStella Laurenzo [](py::object pyClass) { 58013b9322SStella Laurenzo std::string dialectNamespace = 59013b9322SStella Laurenzo pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 60013b9322SStella Laurenzo PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 61013b9322SStella Laurenzo return pyClass; 62013b9322SStella Laurenzo }, 63*a6e7d024SStella Laurenzo py::arg("dialect_class"), 64013b9322SStella Laurenzo "Class decorator for registering a custom Dialect wrapper"); 65013b9322SStella Laurenzo m.def( 66013b9322SStella Laurenzo "register_operation", 67013b9322SStella Laurenzo [](py::object dialectClass) -> py::cpp_function { 68013b9322SStella Laurenzo return py::cpp_function( 69013b9322SStella Laurenzo [dialectClass](py::object opClass) -> py::object { 70013b9322SStella Laurenzo std::string operationName = 71013b9322SStella Laurenzo opClass.attr("OPERATION_NAME").cast<std::string>(); 72013b9322SStella Laurenzo auto rawSubclass = PyOpView::createRawSubclass(opClass); 73013b9322SStella Laurenzo PyGlobals::get().registerOperationImpl(operationName, opClass, 74013b9322SStella Laurenzo rawSubclass); 75013b9322SStella Laurenzo 76013b9322SStella Laurenzo // Dict-stuff the new opClass by name onto the dialect class. 77013b9322SStella Laurenzo py::object opClassName = opClass.attr("__name__"); 78013b9322SStella Laurenzo dialectClass.attr(opClassName) = opClass; 79013b9322SStella Laurenzo 80013b9322SStella Laurenzo // Now create a special "Raw" subclass that passes through 81013b9322SStella Laurenzo // construction to the OpView parent (bypasses the intermediate 82013b9322SStella Laurenzo // child's __init__). 83013b9322SStella Laurenzo opClass.attr("_Raw") = rawSubclass; 84013b9322SStella Laurenzo return opClass; 85013b9322SStella Laurenzo }); 86013b9322SStella Laurenzo }, 87*a6e7d024SStella Laurenzo py::arg("dialect_class"), 88*a6e7d024SStella Laurenzo "Produce a class decorator for registering an Operation class as part of " 89*a6e7d024SStella Laurenzo "a dialect"); 90013b9322SStella Laurenzo 91fcd2969dSzhanghb97 // Define and populate IR submodule. 92fcd2969dSzhanghb97 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 93436c6c9cSStella Laurenzo populateIRCore(irModule); 94436c6c9cSStella Laurenzo populateIRAffine(irModule); 95436c6c9cSStella Laurenzo populateIRAttributes(irModule); 9614c92070SAlex Zinenko populateIRInterfaces(irModule); 97436c6c9cSStella Laurenzo populateIRTypes(irModule); 98dc43f785SMehdi Amini 99dc43f785SMehdi Amini // Define and populate PassManager submodule. 100dc43f785SMehdi Amini auto passModule = 101dc43f785SMehdi Amini m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 102dc43f785SMehdi Amini populatePassManagerSubmodule(passModule); 10313cb4317SMehdi Amini 104f13893f6SStella Laurenzo // Define and populate dialect submodules. 10543b9fa3cSNicolas Vasilache auto dialectsModule = m.def_submodule("dialects"); 10643b9fa3cSNicolas Vasilache auto linalgModule = dialectsModule.def_submodule("linalg"); 10743b9fa3cSNicolas Vasilache populateDialectLinalgSubmodule(linalgModule); 108f13893f6SStella Laurenzo populateDialectSparseTensorSubmodule( 109f13893f6SStella Laurenzo dialectsModule.def_submodule("sparse_tensor"), irModule); 110722475a3SStella Laurenzo } 111