1 //===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PythonTestCAPI.h" 10 #include "mlir/Bindings/Python/PybindAdaptors.h" 11 12 namespace py = pybind11; 13 using namespace mlir::python::adaptors; 14 15 PYBIND11_MODULE(_mlirPythonTest, m) { 16 m.def( 17 "register_python_test_dialect", 18 [](MlirContext context, bool load) { 19 MlirDialectHandle pythonTestDialect = 20 mlirGetDialectHandle__python_test__(); 21 mlirDialectHandleRegisterDialect(pythonTestDialect, context); 22 if (load) { 23 mlirDialectHandleLoadDialect(pythonTestDialect, context); 24 } 25 }, 26 py::arg("context"), py::arg("load") = true); 27 28 mlir_attribute_subclass(m, "TestAttr", 29 mlirAttributeIsAPythonTestTestAttribute) 30 .def_classmethod( 31 "get", 32 [](py::object cls, MlirContext ctx) { 33 return cls(mlirPythonTestTestAttributeGet(ctx)); 34 }, 35 py::arg("cls"), py::arg("context") = py::none()); 36 mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType) 37 .def_classmethod( 38 "get", 39 [](py::object cls, MlirContext ctx) { 40 return cls(mlirPythonTestTestTypeGet(ctx)); 41 }, 42 py::arg("cls"), py::arg("context") = py::none()); 43 } 44