195b77f2eSStella Laurenzo //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
295b77f2eSStella Laurenzo //
395b77f2eSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
495b77f2eSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
595b77f2eSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
695b77f2eSStella Laurenzo //
795b77f2eSStella Laurenzo //===----------------------------------------------------------------------===//
895b77f2eSStella Laurenzo 
995b77f2eSStella Laurenzo #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
1095b77f2eSStella Laurenzo #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
1195b77f2eSStella Laurenzo 
125b29d26bSMehdi Amini #include "mlir-c/Support.h"
135b29d26bSMehdi Amini #include "llvm/ADT/Optional.h"
145b29d26bSMehdi Amini #include "llvm/ADT/Twine.h"
155b29d26bSMehdi Amini 
1695b77f2eSStella Laurenzo #include <pybind11/pybind11.h>
172d1362e0SStella Laurenzo #include <pybind11/stl.h>
1895b77f2eSStella Laurenzo 
1995b77f2eSStella Laurenzo namespace mlir {
2095b77f2eSStella Laurenzo namespace python {
2195b77f2eSStella Laurenzo 
/// Sets the Python error indicator to an exception of class `excClass` with
/// the text `message`. Returns a pybind11::error_already_set that is meant to
/// be thrown right away, so that control returns to the Python runtime with
/// that error pending.
///
/// Correct usage:
///   throw SetPyError(PyExc_ValueError, "Foobar'd");
pybind11::error_already_set SetPyError(PyObject *excClass,
                                       const llvm::Twine &message);
2895b77f2eSStella Laurenzo 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
///
/// The wrapper never owns the referent; it only stores its address, so the
/// referent must outlive every use of the wrapper.
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal: it refers to nothing and get() returns null.
  Defaulting() = default;
  /// Wraps `referrent` by address.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped object (null only for a default-constructed wrapper).
  ReferrentTy *get() const { return referrent; }
  /// Member access on the wrapped object. Const-qualified like get(): the
  /// wrapper is a pointer-like handle, so the constness of the handle does
  /// not propagate to the referent.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
59af66cd17SStella Laurenzo 
6095b77f2eSStella Laurenzo } // namespace python
6195b77f2eSStella Laurenzo } // namespace mlir
6295b77f2eSStella Laurenzo 
63af66cd17SStella Laurenzo namespace pybind11 {
64af66cd17SStella Laurenzo namespace detail {
65af66cd17SStella Laurenzo 
66af66cd17SStella Laurenzo template <typename DefaultingTy>
67af66cd17SStella Laurenzo struct MlirDefaultingCaster {
68af66cd17SStella Laurenzo   PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
69af66cd17SStella Laurenzo 
loadMlirDefaultingCaster70af66cd17SStella Laurenzo   bool load(pybind11::handle src, bool) {
71af66cd17SStella Laurenzo     if (src.is_none()) {
72af66cd17SStella Laurenzo       // Note that we do want an exception to propagate from here as it will be
73af66cd17SStella Laurenzo       // the most informative.
74af66cd17SStella Laurenzo       value = DefaultingTy{DefaultingTy::resolve()};
75af66cd17SStella Laurenzo       return true;
76af66cd17SStella Laurenzo     }
77af66cd17SStella Laurenzo 
78af66cd17SStella Laurenzo     // Unlike many casters that chain, these casters are expected to always
79af66cd17SStella Laurenzo     // succeed, so instead of doing an isinstance check followed by a cast,
80af66cd17SStella Laurenzo     // just cast in one step and handle the exception. Returning false (vs
81af66cd17SStella Laurenzo     // letting the exception propagate) causes higher level signature parsing
82af66cd17SStella Laurenzo     // code to produce nice error messages (other than "Cannot cast...").
83af66cd17SStella Laurenzo     try {
84af66cd17SStella Laurenzo       value = DefaultingTy{
85af66cd17SStella Laurenzo           pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
86af66cd17SStella Laurenzo       return true;
874726a402SStella Laurenzo     } catch (std::exception &) {
88af66cd17SStella Laurenzo       return false;
89af66cd17SStella Laurenzo     }
90af66cd17SStella Laurenzo   }
91af66cd17SStella Laurenzo 
castMlirDefaultingCaster92af66cd17SStella Laurenzo   static handle cast(DefaultingTy src, return_value_policy policy,
93af66cd17SStella Laurenzo                      handle parent) {
94af66cd17SStella Laurenzo     return pybind11::cast(src, policy);
95af66cd17SStella Laurenzo   }
96af66cd17SStella Laurenzo };
97af66cd17SStella Laurenzo 
/// Teaches pybind11 to convert `llvm::Optional<T>` by reusing its generic
/// `optional_caster` (from pybind11/stl.h). Presumably this gives it the same
/// None <-> empty mapping pybind11 uses for `std::optional<T>`.
template <typename T>
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
100af66cd17SStella Laurenzo } // namespace detail
101af66cd17SStella Laurenzo } // namespace pybind11
102af66cd17SStella Laurenzo 
1035b29d26bSMehdi Amini //------------------------------------------------------------------------------
1045b29d26bSMehdi Amini // Conversion utilities.
1055b29d26bSMehdi Amini //------------------------------------------------------------------------------
1065b29d26bSMehdi Amini 
1075b29d26bSMehdi Amini namespace mlir {
1085b29d26bSMehdi Amini 
1095b29d26bSMehdi Amini /// Accumulates into a python string from a method that accepts an
1105b29d26bSMehdi Amini /// MlirStringCallback.
1115b29d26bSMehdi Amini struct PyPrintAccumulator {
1125b29d26bSMehdi Amini   pybind11::list parts;
1135b29d26bSMehdi Amini 
1145b29d26bSMehdi Amini   void *getUserData() { return this; }
1155b29d26bSMehdi Amini 
1165b29d26bSMehdi Amini   MlirStringCallback getCallback() {
1175f0c1e38Szhanghb97     return [](MlirStringRef part, void *userData) {
1185b29d26bSMehdi Amini       PyPrintAccumulator *printAccum =
1195b29d26bSMehdi Amini           static_cast<PyPrintAccumulator *>(userData);
1205f0c1e38Szhanghb97       pybind11::str pyPart(part.data,
1215f0c1e38Szhanghb97                            part.length); // Decodes as UTF-8 by default.
1225b29d26bSMehdi Amini       printAccum->parts.append(std::move(pyPart));
1235b29d26bSMehdi Amini     };
1245b29d26bSMehdi Amini   }
1255b29d26bSMehdi Amini 
1265b29d26bSMehdi Amini   pybind11::str join() {
1275b29d26bSMehdi Amini     pybind11::str delim("", 0);
1285b29d26bSMehdi Amini     return delim.attr("join")(parts);
1295b29d26bSMehdi Amini   }
1305b29d26bSMehdi Amini };
1315b29d26bSMehdi Amini 
1325b29d26bSMehdi Amini /// Accumulates int a python file-like object, either writing text (default)
1335b29d26bSMehdi Amini /// or binary.
1345b29d26bSMehdi Amini class PyFileAccumulator {
1355b29d26bSMehdi Amini public:
136e8d07395SMehdi Amini   PyFileAccumulator(const pybind11::object &fileObject, bool binary)
1375b29d26bSMehdi Amini       : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
1385b29d26bSMehdi Amini 
1395b29d26bSMehdi Amini   void *getUserData() { return this; }
1405b29d26bSMehdi Amini 
1415b29d26bSMehdi Amini   MlirStringCallback getCallback() {
1425f0c1e38Szhanghb97     return [](MlirStringRef part, void *userData) {
143babad7c5SAdrian Kuegel       pybind11::gil_scoped_acquire acquire;
1445b29d26bSMehdi Amini       PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
1455b29d26bSMehdi Amini       if (accum->binary) {
1465b29d26bSMehdi Amini         // Note: Still has to copy and not avoidable with this API.
1475f0c1e38Szhanghb97         pybind11::bytes pyBytes(part.data, part.length);
1485b29d26bSMehdi Amini         accum->pyWriteFunction(pyBytes);
1495b29d26bSMehdi Amini       } else {
1505f0c1e38Szhanghb97         pybind11::str pyStr(part.data,
1515f0c1e38Szhanghb97                             part.length); // Decodes as UTF-8 by default.
1525b29d26bSMehdi Amini         accum->pyWriteFunction(pyStr);
1535b29d26bSMehdi Amini       }
1545b29d26bSMehdi Amini     };
1555b29d26bSMehdi Amini   }
1565b29d26bSMehdi Amini 
1575b29d26bSMehdi Amini private:
1585b29d26bSMehdi Amini   pybind11::object pyWriteFunction;
1595b29d26bSMehdi Amini   bool binary;
1605b29d26bSMehdi Amini };
1615b29d26bSMehdi Amini 
1625b29d26bSMehdi Amini /// Accumulates into a python string from a method that is expected to make
1635b29d26bSMehdi Amini /// one (no more, no less) call to the callback (asserts internally on
1645b29d26bSMehdi Amini /// violation).
1655b29d26bSMehdi Amini struct PySinglePartStringAccumulator {
1665b29d26bSMehdi Amini   void *getUserData() { return this; }
1675b29d26bSMehdi Amini 
1685b29d26bSMehdi Amini   MlirStringCallback getCallback() {
1695f0c1e38Szhanghb97     return [](MlirStringRef part, void *userData) {
1705b29d26bSMehdi Amini       PySinglePartStringAccumulator *accum =
1715b29d26bSMehdi Amini           static_cast<PySinglePartStringAccumulator *>(userData);
1725b29d26bSMehdi Amini       assert(!accum->invoked &&
1735b29d26bSMehdi Amini              "PySinglePartStringAccumulator called back multiple times");
1745b29d26bSMehdi Amini       accum->invoked = true;
1755f0c1e38Szhanghb97       accum->value = pybind11::str(part.data, part.length);
1765b29d26bSMehdi Amini     };
1775b29d26bSMehdi Amini   }
1785b29d26bSMehdi Amini 
1795b29d26bSMehdi Amini   pybind11::str takeValue() {
1805b29d26bSMehdi Amini     assert(invoked && "PySinglePartStringAccumulator not called back");
1815b29d26bSMehdi Amini     return std::move(value);
1825b29d26bSMehdi Amini   }
1835b29d26bSMehdi Amini 
1845b29d26bSMehdi Amini private:
1855b29d26bSMehdi Amini   pybind11::str value;
1865b29d26bSMehdi Amini   bool invoked = false;
1875b29d26bSMehdi Amini };
1885b29d26bSMehdi Amini 
1896c7e6b2cSAlex Zinenko /// A CRTP base class for pseudo-containers willing to support Python-type
1906c7e6b2cSAlex Zinenko /// slicing access on top of indexed access. Calling ::bind on this class
1916c7e6b2cSAlex Zinenko /// will define `__len__` as well as `__getitem__` with integer and slice
1926c7e6b2cSAlex Zinenko /// arguments.
1936c7e6b2cSAlex Zinenko ///
1946c7e6b2cSAlex Zinenko /// This is intended for pseudo-containers that can refer to arbitrary slices of
1956c7e6b2cSAlex Zinenko /// underlying storage indexed by a single integer. Indexing those with an
1966c7e6b2cSAlex Zinenko /// integer produces an instance of ElementTy. Indexing those with a slice
1976c7e6b2cSAlex Zinenko /// produces a new instance of Derived, which can be sliced further.
1986c7e6b2cSAlex Zinenko ///
1996c7e6b2cSAlex Zinenko /// A derived class must provide the following:
2006c7e6b2cSAlex Zinenko ///   - a `static const char *pyClassName ` field containing the name of the
2016c7e6b2cSAlex Zinenko ///     Python class to bind;
202*ee168fb9SAlex Zinenko ///   - an instance method `intptr_t getRawNumElements()` that returns the
203*ee168fb9SAlex Zinenko ///   number
2046c7e6b2cSAlex Zinenko ///     of elements in the backing container (NOT that of the slice);
205*ee168fb9SAlex Zinenko ///   - an instance method `ElementTy getRawElement(intptr_t)` that returns a
206*ee168fb9SAlex Zinenko ///     single element at the given linear index (NOT slice index);
2076c7e6b2cSAlex Zinenko ///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
2086c7e6b2cSAlex Zinenko ///     constructs a new instance of the derived pseudo-container with the
2096c7e6b2cSAlex Zinenko ///     given slice parameters (to be forwarded to the Sliceable constructor).
2106c7e6b2cSAlex Zinenko ///
211*ee168fb9SAlex Zinenko /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
212*ee168fb9SAlex Zinenko /// throw.
213429b0cf1SStella Laurenzo ///
2146c7e6b2cSAlex Zinenko /// A derived class may additionally define:
2156c7e6b2cSAlex Zinenko ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
2166c7e6b2cSAlex Zinenko ///     the python class.
2176c7e6b2cSAlex Zinenko template <typename Derived, typename ElementTy>
2186c7e6b2cSAlex Zinenko class Sliceable {
2196c7e6b2cSAlex Zinenko protected:
2206c7e6b2cSAlex Zinenko   using ClassTy = pybind11::class_<Derived>;
2216c7e6b2cSAlex Zinenko 
222*ee168fb9SAlex Zinenko   /// Transforms `index` into a legal value to access the underlying sequence.
223*ee168fb9SAlex Zinenko   /// Returns <0 on failure.
22463d16d06SMike Urbach   intptr_t wrapIndex(intptr_t index) {
22563d16d06SMike Urbach     if (index < 0)
22663d16d06SMike Urbach       index = length + index;
227429b0cf1SStella Laurenzo     if (index < 0 || index >= length)
228429b0cf1SStella Laurenzo       return -1;
22963d16d06SMike Urbach     return index;
23063d16d06SMike Urbach   }
23163d16d06SMike Urbach 
232*ee168fb9SAlex Zinenko   /// Computes the linear index given the current slice properties.
233*ee168fb9SAlex Zinenko   intptr_t linearizeIndex(intptr_t index) {
234*ee168fb9SAlex Zinenko     intptr_t linearIndex = index * step + startIndex;
235*ee168fb9SAlex Zinenko     assert(linearIndex >= 0 &&
236*ee168fb9SAlex Zinenko            linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
237*ee168fb9SAlex Zinenko            "linear index out of bounds, the slice is ill-formed");
238*ee168fb9SAlex Zinenko     return linearIndex;
239*ee168fb9SAlex Zinenko   }
240*ee168fb9SAlex Zinenko 
2416c7e6b2cSAlex Zinenko   /// Returns the element at the given slice index. Supports negative indices
242429b0cf1SStella Laurenzo   /// by taking elements in inverse order. Returns a nullptr object if out
243429b0cf1SStella Laurenzo   /// of bounds.
244429b0cf1SStella Laurenzo   pybind11::object getItem(intptr_t index) {
2456c7e6b2cSAlex Zinenko     // Negative indices mean we count from the end.
24663d16d06SMike Urbach     index = wrapIndex(index);
247429b0cf1SStella Laurenzo     if (index < 0) {
248429b0cf1SStella Laurenzo       PyErr_SetString(PyExc_IndexError, "index out of range");
249429b0cf1SStella Laurenzo       return {};
250429b0cf1SStella Laurenzo     }
2516c7e6b2cSAlex Zinenko 
252429b0cf1SStella Laurenzo     return pybind11::cast(
253*ee168fb9SAlex Zinenko         static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
2546c7e6b2cSAlex Zinenko   }
2556c7e6b2cSAlex Zinenko 
2566c7e6b2cSAlex Zinenko   /// Returns a new instance of the pseudo-container restricted to the given
257429b0cf1SStella Laurenzo   /// slice. Returns a nullptr object on failure.
258429b0cf1SStella Laurenzo   pybind11::object getItemSlice(PyObject *slice) {
2596c7e6b2cSAlex Zinenko     ssize_t start, stop, extraStep, sliceLength;
260429b0cf1SStella Laurenzo     if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
261429b0cf1SStella Laurenzo                              &sliceLength) != 0) {
262429b0cf1SStella Laurenzo       PyErr_SetString(PyExc_IndexError, "index out of range");
263429b0cf1SStella Laurenzo       return {};
2646c7e6b2cSAlex Zinenko     }
265429b0cf1SStella Laurenzo     return pybind11::cast(static_cast<Derived *>(this)->slice(
266429b0cf1SStella Laurenzo         startIndex + start * step, sliceLength, step * extraStep));
267429b0cf1SStella Laurenzo   }
268429b0cf1SStella Laurenzo 
269429b0cf1SStella Laurenzo public:
270429b0cf1SStella Laurenzo   explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
271429b0cf1SStella Laurenzo       : startIndex(startIndex), length(length), step(step) {
272429b0cf1SStella Laurenzo     assert(length >= 0 && "expected non-negative slice length");
2736c7e6b2cSAlex Zinenko   }
2746c7e6b2cSAlex Zinenko 
275*ee168fb9SAlex Zinenko   /// Returns the `index`-th element in the slice, supports negative indices.
276*ee168fb9SAlex Zinenko   /// Throws if the index is out of bounds.
277*ee168fb9SAlex Zinenko   ElementTy getElement(intptr_t index) {
278*ee168fb9SAlex Zinenko     // Negative indices mean we count from the end.
279*ee168fb9SAlex Zinenko     index = wrapIndex(index);
280*ee168fb9SAlex Zinenko     if (index < 0) {
281*ee168fb9SAlex Zinenko       throw pybind11::index_error("index out of range");
282*ee168fb9SAlex Zinenko     }
283*ee168fb9SAlex Zinenko 
284*ee168fb9SAlex Zinenko     return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
285*ee168fb9SAlex Zinenko   }
286*ee168fb9SAlex Zinenko 
287*ee168fb9SAlex Zinenko   /// Returns the size of slice.
288*ee168fb9SAlex Zinenko   intptr_t size() { return length; }
289*ee168fb9SAlex Zinenko 
290afeda4b9SAlex Zinenko   /// Returns a new vector (mapped to Python list) containing elements from two
291afeda4b9SAlex Zinenko   /// slices. The new vector is necessary because slices may not be contiguous
292afeda4b9SAlex Zinenko   /// or even come from the same original sequence.
293afeda4b9SAlex Zinenko   std::vector<ElementTy> dunderAdd(Derived &other) {
294afeda4b9SAlex Zinenko     std::vector<ElementTy> elements;
295afeda4b9SAlex Zinenko     elements.reserve(length + other.length);
296afeda4b9SAlex Zinenko     for (intptr_t i = 0; i < length; ++i) {
297429b0cf1SStella Laurenzo       elements.push_back(static_cast<Derived *>(this)->getElement(i));
298afeda4b9SAlex Zinenko     }
299afeda4b9SAlex Zinenko     for (intptr_t i = 0; i < other.length; ++i) {
300*ee168fb9SAlex Zinenko       elements.push_back(static_cast<Derived *>(&other)->getElement(i));
301afeda4b9SAlex Zinenko     }
302afeda4b9SAlex Zinenko     return elements;
303afeda4b9SAlex Zinenko   }
304afeda4b9SAlex Zinenko 
3056c7e6b2cSAlex Zinenko   /// Binds the indexing and length methods in the Python class.
3066c7e6b2cSAlex Zinenko   static void bind(pybind11::module &m) {
3078dca953dSSean Silva     auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
3088dca953dSSean Silva                                            pybind11::module_local())
309afeda4b9SAlex Zinenko                      .def("__add__", &Sliceable::dunderAdd);
3106c7e6b2cSAlex Zinenko     Derived::bindDerived(clazz);
311429b0cf1SStella Laurenzo 
312429b0cf1SStella Laurenzo     // Manually implement the sequence protocol via the C API. We do this
313429b0cf1SStella Laurenzo     // because it is approx 4x faster than via pybind11, largely because that
314429b0cf1SStella Laurenzo     // formulation requires a C++ exception to be thrown to detect end of
315429b0cf1SStella Laurenzo     // sequence.
316429b0cf1SStella Laurenzo     // Since we are in a C-context, any C++ exception that happens here
317429b0cf1SStella Laurenzo     // will terminate the program. There is nothing in this implementation
318429b0cf1SStella Laurenzo     // that should throw in a non-terminal way, so we forgo further
319429b0cf1SStella Laurenzo     // exception marshalling.
320429b0cf1SStella Laurenzo     // See: https://github.com/pybind/pybind11/issues/2842
321429b0cf1SStella Laurenzo     auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
322429b0cf1SStella Laurenzo     assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
323429b0cf1SStella Laurenzo            "must be heap type");
324429b0cf1SStella Laurenzo     heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
325429b0cf1SStella Laurenzo       auto self = pybind11::cast<Derived *>(rawSelf);
326429b0cf1SStella Laurenzo       return self->length;
327429b0cf1SStella Laurenzo     };
328429b0cf1SStella Laurenzo     // sq_item is called as part of the sequence protocol for iteration,
329429b0cf1SStella Laurenzo     // list construction, etc.
330429b0cf1SStella Laurenzo     heap_type->as_sequence.sq_item =
331429b0cf1SStella Laurenzo         +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
332429b0cf1SStella Laurenzo       auto self = pybind11::cast<Derived *>(rawSelf);
333429b0cf1SStella Laurenzo       return self->getItem(index).release().ptr();
334429b0cf1SStella Laurenzo     };
335429b0cf1SStella Laurenzo     // mp_subscript is used for both slices and integer lookups.
336429b0cf1SStella Laurenzo     heap_type->as_mapping.mp_subscript =
337429b0cf1SStella Laurenzo         +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
338429b0cf1SStella Laurenzo       auto self = pybind11::cast<Derived *>(rawSelf);
339429b0cf1SStella Laurenzo       Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
340429b0cf1SStella Laurenzo       if (!PyErr_Occurred()) {
341429b0cf1SStella Laurenzo         // Integer indexing.
342429b0cf1SStella Laurenzo         return self->getItem(index).release().ptr();
343429b0cf1SStella Laurenzo       }
344429b0cf1SStella Laurenzo       PyErr_Clear();
345429b0cf1SStella Laurenzo 
346429b0cf1SStella Laurenzo       // Assume slice-based indexing.
347429b0cf1SStella Laurenzo       if (PySlice_Check(rawSubscript)) {
348429b0cf1SStella Laurenzo         return self->getItemSlice(rawSubscript).release().ptr();
349429b0cf1SStella Laurenzo       }
350429b0cf1SStella Laurenzo 
351429b0cf1SStella Laurenzo       PyErr_SetString(PyExc_ValueError, "expected integer or slice");
352429b0cf1SStella Laurenzo       return nullptr;
353429b0cf1SStella Laurenzo     };
3546c7e6b2cSAlex Zinenko   }
3556c7e6b2cSAlex Zinenko 
3566c7e6b2cSAlex Zinenko   /// Hook for derived classes willing to bind more methods.
3576c7e6b2cSAlex Zinenko   static void bindDerived(ClassTy &) {}
3586c7e6b2cSAlex Zinenko 
3596c7e6b2cSAlex Zinenko private:
3606c7e6b2cSAlex Zinenko   intptr_t startIndex;
3616c7e6b2cSAlex Zinenko   intptr_t length;
3626c7e6b2cSAlex Zinenko   intptr_t step;
3636c7e6b2cSAlex Zinenko };
3646c7e6b2cSAlex Zinenko 
3655b29d26bSMehdi Amini } // namespace mlir
3665b29d26bSMehdi Amini 
36795b77f2eSStella Laurenzo #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
368