1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 
12 #include "mlir-c/Support.h"
13 #include "llvm/ADT/Optional.h"
14 #include "llvm/ADT/Twine.h"
15 
16 #include <pybind11/pybind11.h>
17 #include <pybind11/stl.h>
18 
19 
20 namespace mlir {
21 namespace python {
22 
/// Sets a Python error and returns an exception object that is ready to be
/// thrown in order to return control back to the Python runtime.
///
/// \param excClass Python exception class to raise (e.g. `PyExc_ValueError`).
/// \param message  Text attached to the raised Python error.
/// \returns the exception object the caller is expected to `throw`.
///
/// Correct usage:
///   throw SetPyError(PyExc_ValueError, "Foobar'd");
pybind11::error_already_set SetPyError(PyObject *excClass,
                                       const llvm::Twine &message);
29 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  /// Wraps an existing object. Only its address is stored, so the wrapper does
  /// not own the object and the object must outlive the wrapper.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer; null for a default-constructed instance.
  /// Const-qualified because the wrapper itself is never modified by access
  /// (pointer-like semantics: constness of the wrapper does not propagate to
  /// the pointee).
  ReferrentTy *get() const { return referrent; }
  /// Member access on the wrapped object, mirroring `get()`.
  ReferrentTy *operator->() const { return referrent; }

private:
  /// Non-owning pointer to the wrapped object.
  ReferrentTy *referrent = nullptr;
};
60 
61 } // namespace python
62 } // namespace mlir
63 
64 namespace pybind11 {
65 namespace detail {
66 
67 template <typename DefaultingTy>
68 struct MlirDefaultingCaster {
69   PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
70 
71   bool load(pybind11::handle src, bool) {
72     if (src.is_none()) {
73       // Note that we do want an exception to propagate from here as it will be
74       // the most informative.
75       value = DefaultingTy{DefaultingTy::resolve()};
76       return true;
77     }
78 
79     // Unlike many casters that chain, these casters are expected to always
80     // succeed, so instead of doing an isinstance check followed by a cast,
81     // just cast in one step and handle the exception. Returning false (vs
82     // letting the exception propagate) causes higher level signature parsing
83     // code to produce nice error messages (other than "Cannot cast...").
84     try {
85       value = DefaultingTy{
86           pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
87       return true;
88     } catch (std::exception &e) {
89       return false;
90     }
91   }
92 
93   static handle cast(DefaultingTy src, return_value_policy policy,
94                      handle parent) {
95     return pybind11::cast(src, policy);
96   }
97 };
98 
/// Makes `llvm::Optional<T>` usable in bound signatures by inheriting the
/// conversion logic from pybind11's generic `optional_caster`.
template <typename T>
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
101 } // namespace detail
102 } // namespace pybind11
103 
104 //------------------------------------------------------------------------------
105 // Conversion utilities.
106 //------------------------------------------------------------------------------
107 
108 namespace mlir {
109 
110 /// Accumulates into a python string from a method that accepts an
111 /// MlirStringCallback.
112 struct PyPrintAccumulator {
113   pybind11::list parts;
114 
115   void *getUserData() { return this; }
116 
117   MlirStringCallback getCallback() {
118     return [](const char *part, intptr_t size, void *userData) {
119       PyPrintAccumulator *printAccum =
120           static_cast<PyPrintAccumulator *>(userData);
121       pybind11::str pyPart(part, size); // Decodes as UTF-8 by default.
122       printAccum->parts.append(std::move(pyPart));
123     };
124   }
125 
126   pybind11::str join() {
127     pybind11::str delim("", 0);
128     return delim.attr("join")(parts);
129   }
130 };
131 
132 /// Accumulates int a python file-like object, either writing text (default)
133 /// or binary.
134 class PyFileAccumulator {
135 public:
136   PyFileAccumulator(pybind11::object fileObject, bool binary)
137       : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
138 
139   void *getUserData() { return this; }
140 
141   MlirStringCallback getCallback() {
142     return [](const char *part, intptr_t size, void *userData) {
143       pybind11::gil_scoped_acquire();
144       PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
145       if (accum->binary) {
146         // Note: Still has to copy and not avoidable with this API.
147         pybind11::bytes pyBytes(part, size);
148         accum->pyWriteFunction(pyBytes);
149       } else {
150         pybind11::str pyStr(part, size); // Decodes as UTF-8 by default.
151         accum->pyWriteFunction(pyStr);
152       }
153     };
154   }
155 
156 private:
157   pybind11::object pyWriteFunction;
158   bool binary;
159 };
160 
161 /// Accumulates into a python string from a method that is expected to make
162 /// one (no more, no less) call to the callback (asserts internally on
163 /// violation).
164 struct PySinglePartStringAccumulator {
165   void *getUserData() { return this; }
166 
167   MlirStringCallback getCallback() {
168     return [](const char *part, intptr_t size, void *userData) {
169       PySinglePartStringAccumulator *accum =
170           static_cast<PySinglePartStringAccumulator *>(userData);
171       assert(!accum->invoked &&
172              "PySinglePartStringAccumulator called back multiple times");
173       accum->invoked = true;
174       accum->value = pybind11::str(part, size);
175     };
176   }
177 
178   pybind11::str takeValue() {
179     assert(invoked && "PySinglePartStringAccumulator not called back");
180     return std::move(value);
181   }
182 
183 private:
184   pybind11::str value;
185   bool invoked = false;
186 };
187 
188 /// A CRTP base class for pseudo-containers willing to support Python-type
189 /// slicing access on top of indexed access. Calling ::bind on this class
190 /// will define `__len__` as well as `__getitem__` with integer and slice
191 /// arguments.
192 ///
193 /// This is intended for pseudo-containers that can refer to arbitrary slices of
194 /// underlying storage indexed by a single integer. Indexing those with an
195 /// integer produces an instance of ElementTy. Indexing those with a slice
196 /// produces a new instance of Derived, which can be sliced further.
197 ///
198 /// A derived class must provide the following:
199 ///   - a `static const char *pyClassName ` field containing the name of the
200 ///     Python class to bind;
201 ///   - an instance method `intptr_t getNumElements()` that returns the number
202 ///     of elements in the backing container (NOT that of the slice);
203 ///   - an instance method `ElementTy getElement(intptr_t)` that returns a
204 ///     single element at the given index.
205 ///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
206 ///     constructs a new instance of the derived pseudo-container with the
207 ///     given slice parameters (to be forwarded to the Sliceable constructor).
208 ///
209 /// A derived class may additionally define:
210 ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
211 ///     the python class.
212 template <typename Derived, typename ElementTy>
213 class Sliceable {
214 protected:
215   using ClassTy = pybind11::class_<Derived>;
216 
217 public:
218   explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
219       : startIndex(startIndex), length(length), step(step) {
220     assert(length >= 0 && "expected non-negative slice length");
221   }
222 
223   /// Returns the length of the slice.
224   intptr_t dunderLen() const { return length; }
225 
226   /// Returns the element at the given slice index. Supports negative indices
227   /// by taking elements in inverse order. Throws if the index is out of bounds.
228   ElementTy dunderGetItem(intptr_t index) {
229     // Negative indices mean we count from the end.
230     if (index < 0)
231       index = length + index;
232     if (index < 0 || index >= length) {
233       throw python::SetPyError(PyExc_IndexError,
234                                "attempt to access out of bounds");
235     }
236 
237     // Compute the linear index given the current slice properties.
238     int linearIndex = index * step + startIndex;
239     assert(linearIndex >= 0 &&
240            linearIndex < static_cast<Derived *>(this)->getNumElements() &&
241            "linear index out of bounds, the slice is ill-formed");
242     return static_cast<Derived *>(this)->getElement(linearIndex);
243   }
244 
245   /// Returns a new instance of the pseudo-container restricted to the given
246   /// slice.
247   Derived dunderGetItemSlice(pybind11::slice slice) {
248     ssize_t start, stop, extraStep, sliceLength;
249     if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
250       throw python::SetPyError(PyExc_IndexError,
251                                "attempt to access out of bounds");
252     }
253     return static_cast<Derived *>(this)->slice(startIndex + start * step,
254                                                sliceLength, step * extraStep);
255   }
256 
257   /// Binds the indexing and length methods in the Python class.
258   static void bind(pybind11::module &m) {
259     auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName)
260                      .def("__len__", &Sliceable::dunderLen)
261                      .def("__getitem__", &Sliceable::dunderGetItem)
262                      .def("__getitem__", &Sliceable::dunderGetItemSlice);
263     Derived::bindDerived(clazz);
264   }
265 
266   /// Hook for derived classes willing to bind more methods.
267   static void bindDerived(ClassTy &) {}
268 
269 private:
270   intptr_t startIndex;
271   intptr_t length;
272   intptr_t step;
273 };
274 
275 } // namespace mlir
276 
277 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
278