1 //===- PybindAdaptors.h - Adaptors for interop with MLIR APIs -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file contains adaptors for clients of the core MLIR Python APIs to 9 // interop via MLIR CAPI types. The facilities here do not depend on 10 // implementation details of the MLIR Python API and do not introduce C++-level 11 // dependencies with it (requiring only Python and CAPI-level dependencies). 12 // 13 // It is encouraged to be used both in-tree and out-of-tree. For in-tree use 14 // cases, it should be used for dialect implementations (versus relying on 15 // Pybind-based internals of the core libraries). 16 //===----------------------------------------------------------------------===// 17 18 #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 19 #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 20 21 #include <pybind11/pybind11.h> 22 #include <pybind11/pytypes.h> 23 #include <pybind11/stl.h> 24 25 #include "mlir-c/Bindings/Python/Interop.h" 26 #include "mlir-c/IR.h" 27 28 #include "llvm/ADT/Optional.h" 29 #include "llvm/ADT/Twine.h" 30 31 namespace py = pybind11; 32 33 // Raw CAPI type casters need to be declared before use, so always include them 34 // first. 35 namespace pybind11 { 36 namespace detail { 37 38 template <typename T> 39 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {}; 40 41 /// Helper to convert a presumed MLIR API object to a capsule, accepting either 42 /// an explicit Capsule (which can happen when two C APIs are communicating 43 /// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR 44 /// attribute (through which supported MLIR Python API objects export their 45 /// contained API pointer as a capsule). Throws a type error if the object is 46 /// neither. This is intended to be used from type casters, which are invoked 47 /// with a raw handle (unowned). The returned object's lifetime may not extend 48 /// beyond the apiObject handle without explicitly having its refcount increased 49 /// (i.e. on return). 50 static py::object mlirApiObjectToCapsule(py::handle apiObject) { 51 if (PyCapsule_CheckExact(apiObject.ptr())) 52 return py::reinterpret_borrow<py::object>(apiObject); 53 if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { 54 auto repr = py::repr(apiObject).cast<std::string>(); 55 throw py::type_error( 56 (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str()); 57 } 58 return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); 59 } 60 61 // Note: Currently all of the following support cast from py::object to the 62 // Mlir* C-API type, but only a few light-weight, context-bound ones 63 // implicitly cast the other way because the use case has not yet emerged and 64 // ownership is unclear. 65 66 /// Casts object <-> MlirAffineMap. 67 template <> 68 struct type_caster<MlirAffineMap> { 69 PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap")); 70 bool load(handle src, bool) { 71 py::object capsule = mlirApiObjectToCapsule(src); 72 value = mlirPythonCapsuleToAffineMap(capsule.ptr()); 73 if (mlirAffineMapIsNull(value)) { 74 return false; 75 } 76 return !mlirAffineMapIsNull(value); 77 } 78 static handle cast(MlirAffineMap v, return_value_policy, handle) { 79 py::object capsule = 80 py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v)); 81 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 82 .attr("AffineMap") 83 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 84 .release(); 85 } 86 }; 87 88 /// Casts object <-> MlirAttribute. 89 template <> 90 struct type_caster<MlirAttribute> { 91 PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute")); 92 bool load(handle src, bool) { 93 py::object capsule = mlirApiObjectToCapsule(src); 94 value = mlirPythonCapsuleToAttribute(capsule.ptr()); 95 return !mlirAttributeIsNull(value); 96 } 97 static handle cast(MlirAttribute v, return_value_policy, handle) { 98 py::object capsule = 99 py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v)); 100 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 101 .attr("Attribute") 102 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 103 .release(); 104 } 105 }; 106 107 /// Casts object -> MlirContext. 108 template <> 109 struct type_caster<MlirContext> { 110 PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext")); 111 bool load(handle src, bool) { 112 if (src.is_none()) { 113 // Gets the current thread-bound context. 114 // TODO: This raises an error of "No current context" currently. 115 // Update the implementation to pretty-print the helpful error that the 116 // core implementations print in this case. 117 src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 118 .attr("Context") 119 .attr("current"); 120 } 121 py::object capsule = mlirApiObjectToCapsule(src); 122 value = mlirPythonCapsuleToContext(capsule.ptr()); 123 return !mlirContextIsNull(value); 124 } 125 }; 126 127 /// Casts object <-> MlirDialectRegistry. 128 template <> 129 struct type_caster<MlirDialectRegistry> { 130 PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry")); 131 bool load(handle src, bool) { 132 py::object capsule = mlirApiObjectToCapsule(src); 133 value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); 134 return !mlirDialectRegistryIsNull(value); 135 } 136 static handle cast(MlirDialectRegistry v, return_value_policy, handle) { 137 py::object capsule = py::reinterpret_steal<py::object>( 138 mlirPythonDialectRegistryToCapsule(v)); 139 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 140 .attr("DialectRegistry") 141 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 142 .release(); 143 } 144 }; 145 146 /// Casts object <-> MlirLocation. 147 template <> 148 struct type_caster<MlirLocation> { 149 PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation")); 150 bool load(handle src, bool) { 151 if (src.is_none()) { 152 // Gets the current thread-bound context. 153 src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 154 .attr("Location") 155 .attr("current"); 156 } 157 py::object capsule = mlirApiObjectToCapsule(src); 158 value = mlirPythonCapsuleToLocation(capsule.ptr()); 159 return !mlirLocationIsNull(value); 160 } 161 static handle cast(MlirLocation v, return_value_policy, handle) { 162 py::object capsule = 163 py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v)); 164 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 165 .attr("Location") 166 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 167 .release(); 168 } 169 }; 170 171 /// Casts object <-> MlirModule. 172 template <> 173 struct type_caster<MlirModule> { 174 PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule")); 175 bool load(handle src, bool) { 176 py::object capsule = mlirApiObjectToCapsule(src); 177 value = mlirPythonCapsuleToModule(capsule.ptr()); 178 return !mlirModuleIsNull(value); 179 } 180 static handle cast(MlirModule v, return_value_policy, handle) { 181 py::object capsule = 182 py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v)); 183 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 184 .attr("Module") 185 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 186 .release(); 187 }; 188 }; 189 190 /// Casts object <-> MlirOperation. 191 template <> 192 struct type_caster<MlirOperation> { 193 PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation")); 194 bool load(handle src, bool) { 195 py::object capsule = mlirApiObjectToCapsule(src); 196 value = mlirPythonCapsuleToOperation(capsule.ptr()); 197 return !mlirOperationIsNull(value); 198 } 199 static handle cast(MlirOperation v, return_value_policy, handle) { 200 if (v.ptr == nullptr) 201 return py::none(); 202 py::object capsule = 203 py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v)); 204 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 205 .attr("Operation") 206 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 207 .release(); 208 }; 209 }; 210 211 /// Casts object -> MlirPassManager. 212 template <> 213 struct type_caster<MlirPassManager> { 214 PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager")); 215 bool load(handle src, bool) { 216 py::object capsule = mlirApiObjectToCapsule(src); 217 value = mlirPythonCapsuleToPassManager(capsule.ptr()); 218 return !mlirPassManagerIsNull(value); 219 } 220 }; 221 222 /// Casts object <-> MlirType. 223 template <> 224 struct type_caster<MlirType> { 225 PYBIND11_TYPE_CASTER(MlirType, _("MlirType")); 226 bool load(handle src, bool) { 227 py::object capsule = mlirApiObjectToCapsule(src); 228 value = mlirPythonCapsuleToType(capsule.ptr()); 229 return !mlirTypeIsNull(value); 230 } 231 static handle cast(MlirType t, return_value_policy, handle) { 232 py::object capsule = 233 py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t)); 234 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 235 .attr("Type") 236 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 237 .release(); 238 } 239 }; 240 241 } // namespace detail 242 } // namespace pybind11 243 244 namespace mlir { 245 namespace python { 246 namespace adaptors { 247 248 /// Provides a facility like py::class_ for defining a new class in a scope, 249 /// but this allows extension of an arbitrary Python class, defining methods 250 /// on it is a similar way. Classes defined in this way are very similar to 251 /// if defined in Python in the usual way but use Pybind11 machinery to do 252 /// it. These are not "real" Pybind11 classes but pure Python classes with no 253 /// relation to a concrete C++ class. 254 /// 255 /// Derived from a discussion upstream: 256 /// https://github.com/pybind/pybind11/issues/1193 257 /// (plus a fair amount of extra curricular poking) 258 /// TODO: If this proves useful, see about including it in pybind11. 259 class pure_subclass { 260 public: 261 pure_subclass(py::handle scope, const char *derivedClassName, 262 const py::object &superClass) { 263 py::object pyType = 264 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 265 py::object metaclass = pyType(superClass); 266 py::dict attributes; 267 268 thisClass = 269 metaclass(derivedClassName, py::make_tuple(superClass), attributes); 270 scope.attr(derivedClassName) = thisClass; 271 } 272 273 template <typename Func, typename... Extra> 274 pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { 275 py::cpp_function cf( 276 std::forward<Func>(f), py::name(name), py::is_method(thisClass), 277 py::sibling(py::getattr(thisClass, name, py::none())), extra...); 278 thisClass.attr(cf.name()) = cf; 279 return *this; 280 } 281 282 template <typename Func, typename... Extra> 283 pure_subclass &def_property_readonly(const char *name, Func &&f, 284 const Extra &...extra) { 285 py::cpp_function cf( 286 std::forward<Func>(f), py::name(name), py::is_method(thisClass), 287 py::sibling(py::getattr(thisClass, name, py::none())), extra...); 288 auto builtinProperty = 289 py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type); 290 thisClass.attr(name) = builtinProperty(cf); 291 return *this; 292 } 293 294 template <typename Func, typename... Extra> 295 pure_subclass &def_staticmethod(const char *name, Func &&f, 296 const Extra &...extra) { 297 static_assert(!std::is_member_function_pointer<Func>::value, 298 "def_staticmethod(...) called with a non-static member " 299 "function pointer"); 300 py::cpp_function cf( 301 std::forward<Func>(f), py::name(name), py::scope(thisClass), 302 py::sibling(py::getattr(thisClass, name, py::none())), extra...); 303 thisClass.attr(cf.name()) = py::staticmethod(cf); 304 return *this; 305 } 306 307 template <typename Func, typename... Extra> 308 pure_subclass &def_classmethod(const char *name, Func &&f, 309 const Extra &...extra) { 310 static_assert(!std::is_member_function_pointer<Func>::value, 311 "def_classmethod(...) called with a non-static member " 312 "function pointer"); 313 py::cpp_function cf( 314 std::forward<Func>(f), py::name(name), py::scope(thisClass), 315 py::sibling(py::getattr(thisClass, name, py::none())), extra...); 316 thisClass.attr(cf.name()) = 317 py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr())); 318 return *this; 319 } 320 321 py::object get_class() const { return thisClass; } 322 323 protected: 324 py::object superClass; 325 py::object thisClass; 326 }; 327 328 /// Creates a custom subclass of mlir.ir.Attribute, implementing a casting 329 /// constructor and type checking methods. 330 class mlir_attribute_subclass : public pure_subclass { 331 public: 332 using IsAFunctionTy = bool (*)(MlirAttribute); 333 334 /// Subclasses by looking up the super-class dynamically. 335 mlir_attribute_subclass(py::handle scope, const char *attrClassName, 336 IsAFunctionTy isaFunction) 337 : mlir_attribute_subclass( 338 scope, attrClassName, isaFunction, 339 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 340 .attr("Attribute")) {} 341 342 /// Subclasses with a provided mlir.ir.Attribute super-class. This must 343 /// be used if the subclass is being defined in the same extension module 344 /// as the mlir.ir class (otherwise, it will trigger a recursive 345 /// initialization). 346 mlir_attribute_subclass(py::handle scope, const char *typeClassName, 347 IsAFunctionTy isaFunction, const py::object &superCls) 348 : pure_subclass(scope, typeClassName, superCls) { 349 // Casting constructor. Note that it hard, if not impossible, to properly 350 // call chain to parent `__init__` in pybind11 due to its special handling 351 // for init functions that don't have a fully constructed self-reference, 352 // which makes it impossible to forward it to `__init__` of a superclass. 353 // Instead, provide a custom `__new__` and call that of a superclass, which 354 // eventually calls `__init__` of the superclass. Since attribute subclasses 355 // have no additional members, we can just return the instance thus created 356 // without amending it. 357 std::string captureTypeName( 358 typeClassName); // As string in case if typeClassName is not static. 359 py::cpp_function newCf( 360 [superCls, isaFunction, captureTypeName](py::object cls, 361 py::object otherAttribute) { 362 MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute); 363 if (!isaFunction(rawAttribute)) { 364 auto origRepr = py::repr(otherAttribute).cast<std::string>(); 365 throw std::invalid_argument( 366 (llvm::Twine("Cannot cast attribute to ") + captureTypeName + 367 " (from " + origRepr + ")") 368 .str()); 369 } 370 py::object self = superCls.attr("__new__")(cls, otherAttribute); 371 return self; 372 }, 373 py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr")); 374 thisClass.attr("__new__") = newCf; 375 376 // 'isinstance' method. 377 def_staticmethod( 378 "isinstance", 379 [isaFunction](MlirAttribute other) { return isaFunction(other); }, 380 py::arg("other_attribute")); 381 } 382 }; 383 384 /// Creates a custom subclass of mlir.ir.Type, implementing a casting 385 /// constructor and type checking methods. 386 class mlir_type_subclass : public pure_subclass { 387 public: 388 using IsAFunctionTy = bool (*)(MlirType); 389 390 /// Subclasses by looking up the super-class dynamically. 391 mlir_type_subclass(py::handle scope, const char *typeClassName, 392 IsAFunctionTy isaFunction) 393 : mlir_type_subclass( 394 scope, typeClassName, isaFunction, 395 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {} 396 397 /// Subclasses with a provided mlir.ir.Type super-class. This must 398 /// be used if the subclass is being defined in the same extension module 399 /// as the mlir.ir class (otherwise, it will trigger a recursive 400 /// initialization). 401 mlir_type_subclass(py::handle scope, const char *typeClassName, 402 IsAFunctionTy isaFunction, const py::object &superCls) 403 : pure_subclass(scope, typeClassName, superCls) { 404 // Casting constructor. Note that it hard, if not impossible, to properly 405 // call chain to parent `__init__` in pybind11 due to its special handling 406 // for init functions that don't have a fully constructed self-reference, 407 // which makes it impossible to forward it to `__init__` of a superclass. 408 // Instead, provide a custom `__new__` and call that of a superclass, which 409 // eventually calls `__init__` of the superclass. Since attribute subclasses 410 // have no additional members, we can just return the instance thus created 411 // without amending it. 412 std::string captureTypeName( 413 typeClassName); // As string in case if typeClassName is not static. 414 py::cpp_function newCf( 415 [superCls, isaFunction, captureTypeName](py::object cls, 416 py::object otherType) { 417 MlirType rawType = py::cast<MlirType>(otherType); 418 if (!isaFunction(rawType)) { 419 auto origRepr = py::repr(otherType).cast<std::string>(); 420 throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + 421 captureTypeName + " (from " + 422 origRepr + ")") 423 .str()); 424 } 425 py::object self = superCls.attr("__new__")(cls, otherType); 426 return self; 427 }, 428 py::name("__new__"), py::arg("cls"), py::arg("cast_from_type")); 429 thisClass.attr("__new__") = newCf; 430 431 // 'isinstance' method. 432 def_staticmethod( 433 "isinstance", 434 [isaFunction](MlirType other) { return isaFunction(other); }, 435 py::arg("other_type")); 436 } 437 }; 438 439 } // namespace adaptors 440 } // namespace python 441 } // namespace mlir 442 443 #endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 444