1 //===- PybindAdaptors.h - Adaptors for interop with MLIR APIs -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file contains adaptors for clients of the core MLIR Python APIs to
9 // interop via MLIR CAPI types. The facilities here do not depend on
10 // implementation details of the MLIR Python API and do not introduce C++-level
11 // dependencies with it (requiring only Python and CAPI-level dependencies).
12 //
// Its use is encouraged both in-tree and out-of-tree. In-tree, dialect
// implementations should use it rather than relying on the pybind11-based
// internals of the core libraries.
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
19 #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
20 
21 #include <pybind11/pybind11.h>
22 #include <pybind11/pytypes.h>
23 #include <pybind11/stl.h>
24 
25 #include "mlir-c/Bindings/Python/Interop.h"
26 #include "mlir-c/IR.h"
27 
28 #include "llvm/ADT/Optional.h"
29 #include "llvm/ADT/Twine.h"
30 
31 namespace py = pybind11;
32 
33 // Raw CAPI type casters need to be declared before use, so always include them
34 // first.
35 namespace pybind11 {
36 namespace detail {
37 
38 template <typename T>
39 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
40 
41 /// Helper to convert a presumed MLIR API object to a capsule, accepting either
42 /// an explicit Capsule (which can happen when two C APIs are communicating
43 /// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR
44 /// attribute (through which supported MLIR Python API objects export their
45 /// contained API pointer as a capsule). Throws a type error if the object is
46 /// neither. This is intended to be used from type casters, which are invoked
47 /// with a raw handle (unowned). The returned object's lifetime may not extend
48 /// beyond the apiObject handle without explicitly having its refcount increased
49 /// (i.e. on return).
50 static py::object mlirApiObjectToCapsule(py::handle apiObject) {
51   if (PyCapsule_CheckExact(apiObject.ptr()))
52     return py::reinterpret_borrow<py::object>(apiObject);
53   if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) {
54     auto repr = py::repr(apiObject).cast<std::string>();
55     throw py::type_error(
56         (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str());
57   }
58   return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
59 }
60 
// Note: All of the following casters support conversion from py::object to the
// corresponding Mlir* C-API type. Only those that define a `cast` method also
// convert in the other direction; MlirContext and MlirPassManager do not,
// because the use case has not yet emerged and ownership is unclear.
65 
66 /// Casts object <-> MlirAffineMap.
67 template <>
68 struct type_caster<MlirAffineMap> {
69   PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"));
70   bool load(handle src, bool) {
71     py::object capsule = mlirApiObjectToCapsule(src);
72     value = mlirPythonCapsuleToAffineMap(capsule.ptr());
73     if (mlirAffineMapIsNull(value)) {
74       return false;
75     }
76     return !mlirAffineMapIsNull(value);
77   }
78   static handle cast(MlirAffineMap v, return_value_policy, handle) {
79     py::object capsule =
80         py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v));
81     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
82         .attr("AffineMap")
83         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
84         .release();
85   }
86 };
87 
88 /// Casts object <-> MlirAttribute.
89 template <>
90 struct type_caster<MlirAttribute> {
91   PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
92   bool load(handle src, bool) {
93     py::object capsule = mlirApiObjectToCapsule(src);
94     value = mlirPythonCapsuleToAttribute(capsule.ptr());
95     return !mlirAttributeIsNull(value);
96   }
97   static handle cast(MlirAttribute v, return_value_policy, handle) {
98     py::object capsule =
99         py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v));
100     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
101         .attr("Attribute")
102         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
103         .release();
104   }
105 };
106 
107 /// Casts object -> MlirContext.
108 template <>
109 struct type_caster<MlirContext> {
110   PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
111   bool load(handle src, bool) {
112     if (src.is_none()) {
113       // Gets the current thread-bound context.
114       // TODO: This raises an error of "No current context" currently.
115       // Update the implementation to pretty-print the helpful error that the
116       // core implementations print in this case.
117       src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
118                 .attr("Context")
119                 .attr("current");
120     }
121     py::object capsule = mlirApiObjectToCapsule(src);
122     value = mlirPythonCapsuleToContext(capsule.ptr());
123     return !mlirContextIsNull(value);
124   }
125 };
126 
127 /// Casts object <-> MlirDialectRegistry.
128 template <>
129 struct type_caster<MlirDialectRegistry> {
130   PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry"));
131   bool load(handle src, bool) {
132     py::object capsule = mlirApiObjectToCapsule(src);
133     value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
134     return !mlirDialectRegistryIsNull(value);
135   }
136   static handle cast(MlirDialectRegistry v, return_value_policy, handle) {
137     py::object capsule = py::reinterpret_steal<py::object>(
138         mlirPythonDialectRegistryToCapsule(v));
139     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
140         .attr("DialectRegistry")
141         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
142         .release();
143   }
144 };
145 
146 /// Casts object <-> MlirLocation.
147 template <>
148 struct type_caster<MlirLocation> {
149   PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
150   bool load(handle src, bool) {
151     if (src.is_none()) {
152       // Gets the current thread-bound context.
153       src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
154                 .attr("Location")
155                 .attr("current");
156     }
157     py::object capsule = mlirApiObjectToCapsule(src);
158     value = mlirPythonCapsuleToLocation(capsule.ptr());
159     return !mlirLocationIsNull(value);
160   }
161   static handle cast(MlirLocation v, return_value_policy, handle) {
162     py::object capsule =
163         py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v));
164     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
165         .attr("Location")
166         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
167         .release();
168   }
169 };
170 
171 /// Casts object <-> MlirModule.
172 template <>
173 struct type_caster<MlirModule> {
174   PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
175   bool load(handle src, bool) {
176     py::object capsule = mlirApiObjectToCapsule(src);
177     value = mlirPythonCapsuleToModule(capsule.ptr());
178     return !mlirModuleIsNull(value);
179   }
180   static handle cast(MlirModule v, return_value_policy, handle) {
181     py::object capsule =
182         py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v));
183     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
184         .attr("Module")
185         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
186         .release();
187   };
188 };
189 
190 /// Casts object <-> MlirOperation.
191 template <>
192 struct type_caster<MlirOperation> {
193   PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
194   bool load(handle src, bool) {
195     py::object capsule = mlirApiObjectToCapsule(src);
196     value = mlirPythonCapsuleToOperation(capsule.ptr());
197     return !mlirOperationIsNull(value);
198   }
199   static handle cast(MlirOperation v, return_value_policy, handle) {
200     if (v.ptr == nullptr)
201       return py::none();
202     py::object capsule =
203         py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v));
204     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
205         .attr("Operation")
206         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
207         .release();
208   };
209 };
210 
211 /// Casts object -> MlirPassManager.
212 template <>
213 struct type_caster<MlirPassManager> {
214   PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
215   bool load(handle src, bool) {
216     py::object capsule = mlirApiObjectToCapsule(src);
217     value = mlirPythonCapsuleToPassManager(capsule.ptr());
218     return !mlirPassManagerIsNull(value);
219   }
220 };
221 
222 /// Casts object <-> MlirType.
223 template <>
224 struct type_caster<MlirType> {
225   PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
226   bool load(handle src, bool) {
227     py::object capsule = mlirApiObjectToCapsule(src);
228     value = mlirPythonCapsuleToType(capsule.ptr());
229     return !mlirTypeIsNull(value);
230   }
231   static handle cast(MlirType t, return_value_policy, handle) {
232     py::object capsule =
233         py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t));
234     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
235         .attr("Type")
236         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
237         .release();
238   }
239 };
240 
241 } // namespace detail
242 } // namespace pybind11
243 
244 namespace mlir {
245 namespace python {
246 namespace adaptors {
247 
248 /// Provides a facility like py::class_ for defining a new class in a scope,
249 /// but this allows extension of an arbitrary Python class, defining methods
250 /// on it is a similar way. Classes defined in this way are very similar to
251 /// if defined in Python in the usual way but use Pybind11 machinery to do
252 /// it. These are not "real" Pybind11 classes but pure Python classes with no
253 /// relation to a concrete C++ class.
254 ///
255 /// Derived from a discussion upstream:
256 ///   https://github.com/pybind/pybind11/issues/1193
257 ///   (plus a fair amount of extra curricular poking)
258 ///   TODO: If this proves useful, see about including it in pybind11.
259 class pure_subclass {
260 public:
261   pure_subclass(py::handle scope, const char *derivedClassName,
262                 const py::object &superClass) {
263     py::object pyType =
264         py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
265     py::object metaclass = pyType(superClass);
266     py::dict attributes;
267 
268     thisClass =
269         metaclass(derivedClassName, py::make_tuple(superClass), attributes);
270     scope.attr(derivedClassName) = thisClass;
271   }
272 
273   template <typename Func, typename... Extra>
274   pure_subclass &def(const char *name, Func &&f, const Extra &...extra) {
275     py::cpp_function cf(
276         std::forward<Func>(f), py::name(name), py::is_method(thisClass),
277         py::sibling(py::getattr(thisClass, name, py::none())), extra...);
278     thisClass.attr(cf.name()) = cf;
279     return *this;
280   }
281 
282   template <typename Func, typename... Extra>
283   pure_subclass &def_property_readonly(const char *name, Func &&f,
284                                        const Extra &...extra) {
285     py::cpp_function cf(
286         std::forward<Func>(f), py::name(name), py::is_method(thisClass),
287         py::sibling(py::getattr(thisClass, name, py::none())), extra...);
288     auto builtinProperty =
289         py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type);
290     thisClass.attr(name) = builtinProperty(cf);
291     return *this;
292   }
293 
294   template <typename Func, typename... Extra>
295   pure_subclass &def_staticmethod(const char *name, Func &&f,
296                                   const Extra &...extra) {
297     static_assert(!std::is_member_function_pointer<Func>::value,
298                   "def_staticmethod(...) called with a non-static member "
299                   "function pointer");
300     py::cpp_function cf(
301         std::forward<Func>(f), py::name(name), py::scope(thisClass),
302         py::sibling(py::getattr(thisClass, name, py::none())), extra...);
303     thisClass.attr(cf.name()) = py::staticmethod(cf);
304     return *this;
305   }
306 
307   template <typename Func, typename... Extra>
308   pure_subclass &def_classmethod(const char *name, Func &&f,
309                                  const Extra &...extra) {
310     static_assert(!std::is_member_function_pointer<Func>::value,
311                   "def_classmethod(...) called with a non-static member "
312                   "function pointer");
313     py::cpp_function cf(
314         std::forward<Func>(f), py::name(name), py::scope(thisClass),
315         py::sibling(py::getattr(thisClass, name, py::none())), extra...);
316     thisClass.attr(cf.name()) =
317         py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
318     return *this;
319   }
320 
321   py::object get_class() const { return thisClass; }
322 
323 protected:
324   py::object superClass;
325   py::object thisClass;
326 };
327 
328 /// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
329 /// constructor and type checking methods.
330 class mlir_attribute_subclass : public pure_subclass {
331 public:
332   using IsAFunctionTy = bool (*)(MlirAttribute);
333 
334   /// Subclasses by looking up the super-class dynamically.
335   mlir_attribute_subclass(py::handle scope, const char *attrClassName,
336                           IsAFunctionTy isaFunction)
337       : mlir_attribute_subclass(
338             scope, attrClassName, isaFunction,
339             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
340                 .attr("Attribute")) {}
341 
342   /// Subclasses with a provided mlir.ir.Attribute super-class. This must
343   /// be used if the subclass is being defined in the same extension module
344   /// as the mlir.ir class (otherwise, it will trigger a recursive
345   /// initialization).
346   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
347                           IsAFunctionTy isaFunction, const py::object &superCls)
348       : pure_subclass(scope, typeClassName, superCls) {
349     // Casting constructor. Note that it hard, if not impossible, to properly
350     // call chain to parent `__init__` in pybind11 due to its special handling
351     // for init functions that don't have a fully constructed self-reference,
352     // which makes it impossible to forward it to `__init__` of a superclass.
353     // Instead, provide a custom `__new__` and call that of a superclass, which
354     // eventually calls `__init__` of the superclass. Since attribute subclasses
355     // have no additional members, we can just return the instance thus created
356     // without amending it.
357     std::string captureTypeName(
358         typeClassName); // As string in case if typeClassName is not static.
359     py::cpp_function newCf(
360         [superCls, isaFunction, captureTypeName](py::object cls,
361                                                  py::object otherAttribute) {
362           MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
363           if (!isaFunction(rawAttribute)) {
364             auto origRepr = py::repr(otherAttribute).cast<std::string>();
365             throw std::invalid_argument(
366                 (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
367                  " (from " + origRepr + ")")
368                     .str());
369           }
370           py::object self = superCls.attr("__new__")(cls, otherAttribute);
371           return self;
372         },
373         py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr"));
374     thisClass.attr("__new__") = newCf;
375 
376     // 'isinstance' method.
377     def_staticmethod(
378         "isinstance",
379         [isaFunction](MlirAttribute other) { return isaFunction(other); },
380         py::arg("other_attribute"));
381   }
382 };
383 
384 /// Creates a custom subclass of mlir.ir.Type, implementing a casting
385 /// constructor and type checking methods.
386 class mlir_type_subclass : public pure_subclass {
387 public:
388   using IsAFunctionTy = bool (*)(MlirType);
389 
390   /// Subclasses by looking up the super-class dynamically.
391   mlir_type_subclass(py::handle scope, const char *typeClassName,
392                      IsAFunctionTy isaFunction)
393       : mlir_type_subclass(
394             scope, typeClassName, isaFunction,
395             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {}
396 
397   /// Subclasses with a provided mlir.ir.Type super-class. This must
398   /// be used if the subclass is being defined in the same extension module
399   /// as the mlir.ir class (otherwise, it will trigger a recursive
400   /// initialization).
401   mlir_type_subclass(py::handle scope, const char *typeClassName,
402                      IsAFunctionTy isaFunction, const py::object &superCls)
403       : pure_subclass(scope, typeClassName, superCls) {
404     // Casting constructor. Note that it hard, if not impossible, to properly
405     // call chain to parent `__init__` in pybind11 due to its special handling
406     // for init functions that don't have a fully constructed self-reference,
407     // which makes it impossible to forward it to `__init__` of a superclass.
408     // Instead, provide a custom `__new__` and call that of a superclass, which
409     // eventually calls `__init__` of the superclass. Since attribute subclasses
410     // have no additional members, we can just return the instance thus created
411     // without amending it.
412     std::string captureTypeName(
413         typeClassName); // As string in case if typeClassName is not static.
414     py::cpp_function newCf(
415         [superCls, isaFunction, captureTypeName](py::object cls,
416                                                  py::object otherType) {
417           MlirType rawType = py::cast<MlirType>(otherType);
418           if (!isaFunction(rawType)) {
419             auto origRepr = py::repr(otherType).cast<std::string>();
420             throw std::invalid_argument((llvm::Twine("Cannot cast type to ") +
421                                          captureTypeName + " (from " +
422                                          origRepr + ")")
423                                             .str());
424           }
425           py::object self = superCls.attr("__new__")(cls, otherType);
426           return self;
427         },
428         py::name("__new__"), py::arg("cls"), py::arg("cast_from_type"));
429     thisClass.attr("__new__") = newCf;
430 
431     // 'isinstance' method.
432     def_staticmethod(
433         "isinstance",
434         [isaFunction](MlirType other) { return isaFunction(other); },
435         py::arg("other_type"));
436   }
437 };
438 
439 } // namespace adaptors
440 } // namespace python
441 } // namespace mlir
442 
443 #endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
444