1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 
12 #include "mlir-c/Support.h"
13 #include "llvm/ADT/Optional.h"
14 #include "llvm/ADT/Twine.h"
15 
16 #include <pybind11/pybind11.h>
17 #include <pybind11/stl.h>
18 
19 namespace mlir {
20 namespace python {
21 
22 // Sets a python error, ready to be thrown to return control back to the
23 // python runtime.
24 // Correct usage:
25 //   throw SetPyError(PyExc_ValueError, "Foobar'd");
26 pybind11::error_already_set SetPyError(PyObject *excClass,
27                                        const llvm::Twine &message);
28 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanism if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  /// Wraps an already resolved referrent. Only its address is stored; the
  /// wrapper never takes ownership of it.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (nullptr for a default-constructed wrapper).
  ReferrentTy *get() const { return referrent; }
  /// Smart-pointer style access. Const-qualified, like `get()`, so that the
  /// wrapper can be dereferenced through a const reference as well: constness
  /// of the wrapper does not imply constness of the referrent.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
59 
60 } // namespace python
61 } // namespace mlir
62 
63 namespace pybind11 {
64 namespace detail {
65 
66 template <typename DefaultingTy>
67 struct MlirDefaultingCaster {
68   PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
69 
70   bool load(pybind11::handle src, bool) {
71     if (src.is_none()) {
72       // Note that we do want an exception to propagate from here as it will be
73       // the most informative.
74       value = DefaultingTy{DefaultingTy::resolve()};
75       return true;
76     }
77 
78     // Unlike many casters that chain, these casters are expected to always
79     // succeed, so instead of doing an isinstance check followed by a cast,
80     // just cast in one step and handle the exception. Returning false (vs
81     // letting the exception propagate) causes higher level signature parsing
82     // code to produce nice error messages (other than "Cannot cast...").
83     try {
84       value = DefaultingTy{
85           pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
86       return true;
87     } catch (std::exception &) {
88       return false;
89     }
90   }
91 
92   static handle cast(DefaultingTy src, return_value_policy policy,
93                      handle parent) {
94     return pybind11::cast(src, policy);
95   }
96 };
97 
98 template <typename T>
99 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
100 } // namespace detail
101 } // namespace pybind11
102 
103 //------------------------------------------------------------------------------
104 // Conversion utilities.
105 //------------------------------------------------------------------------------
106 
107 namespace mlir {
108 
109 /// Accumulates into a python string from a method that accepts an
110 /// MlirStringCallback.
111 struct PyPrintAccumulator {
112   pybind11::list parts;
113 
114   void *getUserData() { return this; }
115 
116   MlirStringCallback getCallback() {
117     return [](MlirStringRef part, void *userData) {
118       PyPrintAccumulator *printAccum =
119           static_cast<PyPrintAccumulator *>(userData);
120       pybind11::str pyPart(part.data,
121                            part.length); // Decodes as UTF-8 by default.
122       printAccum->parts.append(std::move(pyPart));
123     };
124   }
125 
126   pybind11::str join() {
127     pybind11::str delim("", 0);
128     return delim.attr("join")(parts);
129   }
130 };
131 
132 /// Accumulates int a python file-like object, either writing text (default)
133 /// or binary.
134 class PyFileAccumulator {
135 public:
136   PyFileAccumulator(const pybind11::object &fileObject, bool binary)
137       : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
138 
139   void *getUserData() { return this; }
140 
141   MlirStringCallback getCallback() {
142     return [](MlirStringRef part, void *userData) {
143       pybind11::gil_scoped_acquire acquire;
144       PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
145       if (accum->binary) {
146         // Note: Still has to copy and not avoidable with this API.
147         pybind11::bytes pyBytes(part.data, part.length);
148         accum->pyWriteFunction(pyBytes);
149       } else {
150         pybind11::str pyStr(part.data,
151                             part.length); // Decodes as UTF-8 by default.
152         accum->pyWriteFunction(pyStr);
153       }
154     };
155   }
156 
157 private:
158   pybind11::object pyWriteFunction;
159   bool binary;
160 };
161 
162 /// Accumulates into a python string from a method that is expected to make
163 /// one (no more, no less) call to the callback (asserts internally on
164 /// violation).
165 struct PySinglePartStringAccumulator {
166   void *getUserData() { return this; }
167 
168   MlirStringCallback getCallback() {
169     return [](MlirStringRef part, void *userData) {
170       PySinglePartStringAccumulator *accum =
171           static_cast<PySinglePartStringAccumulator *>(userData);
172       assert(!accum->invoked &&
173              "PySinglePartStringAccumulator called back multiple times");
174       accum->invoked = true;
175       accum->value = pybind11::str(part.data, part.length);
176     };
177   }
178 
179   pybind11::str takeValue() {
180     assert(invoked && "PySinglePartStringAccumulator not called back");
181     return std::move(value);
182   }
183 
184 private:
185   pybind11::str value;
186   bool invoked = false;
187 };
188 
189 /// A CRTP base class for pseudo-containers willing to support Python-type
190 /// slicing access on top of indexed access. Calling ::bind on this class
191 /// will define `__len__` as well as `__getitem__` with integer and slice
192 /// arguments.
193 ///
194 /// This is intended for pseudo-containers that can refer to arbitrary slices of
195 /// underlying storage indexed by a single integer. Indexing those with an
196 /// integer produces an instance of ElementTy. Indexing those with a slice
197 /// produces a new instance of Derived, which can be sliced further.
198 ///
199 /// A derived class must provide the following:
200 ///   - a `static const char *pyClassName ` field containing the name of the
201 ///     Python class to bind;
202 ///   - an instance method `intptr_t getNumElements()` that returns the number
203 ///     of elements in the backing container (NOT that of the slice);
204 ///   - an instance method `ElementTy getElement(intptr_t)` that returns a
205 ///     single element at the given index.
206 ///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
207 ///     constructs a new instance of the derived pseudo-container with the
208 ///     given slice parameters (to be forwarded to the Sliceable constructor).
209 ///
210 /// The getNumElements() and getElement(intptr_t) callbacks must not throw.
211 ///
212 /// A derived class may additionally define:
213 ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
214 ///     the python class.
215 template <typename Derived, typename ElementTy>
216 class Sliceable {
217 protected:
218   using ClassTy = pybind11::class_<Derived>;
219 
220   // Transforms `index` into a legal value to access the underlying sequence.
221   // Returns <0 on failure.
222   intptr_t wrapIndex(intptr_t index) {
223     if (index < 0)
224       index = length + index;
225     if (index < 0 || index >= length)
226       return -1;
227     return index;
228   }
229 
230   /// Returns the element at the given slice index. Supports negative indices
231   /// by taking elements in inverse order. Returns a nullptr object if out
232   /// of bounds.
233   pybind11::object getItem(intptr_t index) {
234     // Negative indices mean we count from the end.
235     index = wrapIndex(index);
236     if (index < 0) {
237       PyErr_SetString(PyExc_IndexError, "index out of range");
238       return {};
239     }
240 
241     // Compute the linear index given the current slice properties.
242     int linearIndex = index * step + startIndex;
243     assert(linearIndex >= 0 &&
244            linearIndex < static_cast<Derived *>(this)->getNumElements() &&
245            "linear index out of bounds, the slice is ill-formed");
246     return pybind11::cast(
247         static_cast<Derived *>(this)->getElement(linearIndex));
248   }
249 
250   /// Returns a new instance of the pseudo-container restricted to the given
251   /// slice. Returns a nullptr object on failure.
252   pybind11::object getItemSlice(PyObject *slice) {
253     ssize_t start, stop, extraStep, sliceLength;
254     if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
255                              &sliceLength) != 0) {
256       PyErr_SetString(PyExc_IndexError, "index out of range");
257       return {};
258     }
259     return pybind11::cast(static_cast<Derived *>(this)->slice(
260         startIndex + start * step, sliceLength, step * extraStep));
261   }
262 
263 public:
264   explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
265       : startIndex(startIndex), length(length), step(step) {
266     assert(length >= 0 && "expected non-negative slice length");
267   }
268 
269   /// Returns a new vector (mapped to Python list) containing elements from two
270   /// slices. The new vector is necessary because slices may not be contiguous
271   /// or even come from the same original sequence.
272   std::vector<ElementTy> dunderAdd(Derived &other) {
273     std::vector<ElementTy> elements;
274     elements.reserve(length + other.length);
275     for (intptr_t i = 0; i < length; ++i) {
276       elements.push_back(static_cast<Derived *>(this)->getElement(i));
277     }
278     for (intptr_t i = 0; i < other.length; ++i) {
279       elements.push_back(static_cast<Derived *>(this)->getElement(i));
280     }
281     return elements;
282   }
283 
284   /// Binds the indexing and length methods in the Python class.
285   static void bind(pybind11::module &m) {
286     auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
287                                            pybind11::module_local())
288                      .def("__add__", &Sliceable::dunderAdd);
289     Derived::bindDerived(clazz);
290 
291     // Manually implement the sequence protocol via the C API. We do this
292     // because it is approx 4x faster than via pybind11, largely because that
293     // formulation requires a C++ exception to be thrown to detect end of
294     // sequence.
295     // Since we are in a C-context, any C++ exception that happens here
296     // will terminate the program. There is nothing in this implementation
297     // that should throw in a non-terminal way, so we forgo further
298     // exception marshalling.
299     // See: https://github.com/pybind/pybind11/issues/2842
300     auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
301     assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
302            "must be heap type");
303     heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
304       auto self = pybind11::cast<Derived *>(rawSelf);
305       return self->length;
306     };
307     // sq_item is called as part of the sequence protocol for iteration,
308     // list construction, etc.
309     heap_type->as_sequence.sq_item =
310         +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
311       auto self = pybind11::cast<Derived *>(rawSelf);
312       return self->getItem(index).release().ptr();
313     };
314     // mp_subscript is used for both slices and integer lookups.
315     heap_type->as_mapping.mp_subscript =
316         +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
317       auto self = pybind11::cast<Derived *>(rawSelf);
318       Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
319       if (!PyErr_Occurred()) {
320         // Integer indexing.
321         return self->getItem(index).release().ptr();
322       }
323       PyErr_Clear();
324 
325       // Assume slice-based indexing.
326       if (PySlice_Check(rawSubscript)) {
327         return self->getItemSlice(rawSubscript).release().ptr();
328       }
329 
330       PyErr_SetString(PyExc_ValueError, "expected integer or slice");
331       return nullptr;
332     };
333   }
334 
335   /// Hook for derived classes willing to bind more methods.
336   static void bindDerived(ClassTy &) {}
337 
338 private:
339   intptr_t startIndex;
340   intptr_t length;
341   intptr_t step;
342 };
343 
344 } // namespace mlir
345 
346 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
347